diff --git a/README.md b/README.md
index 9fe1699..3ead826 100644
--- a/README.md
+++ b/README.md
@@ -144,7 +144,8 @@ async def websocket_endpoint(websocket: WebSocket):
 | `--language` | See the list [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
 | `--target-language` | If set, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to English, you can also use `--direct-english-translation`, in which case the STT model tries to output the translation directly. | `None` |
 | `--diarization` | Enable speaker identification | `False` |
-| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
+| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
+| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
 | `--no-vac` | Disable Voice Activity Controller | `False` |
 | `--no-vad` | Disable Voice Activity Detection | `False` |
 | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
diff --git a/whisperlivekit/backend_support.py b/whisperlivekit/backend_support.py
new file mode 100644
index 0000000..a64770a
--- /dev/null
+++ b/whisperlivekit/backend_support.py
@@ -0,0 +1,41 @@
+import importlib.util
+import logging
+import platform
+
+logger = logging.getLogger(__name__)
+
+
+def module_available(module_name):
+    """Return True if the given module can be imported."""
+    return importlib.util.find_spec(module_name) is not None
+
+
+def mlx_backend_available(warn_on_missing=False):
+    is_macos = platform.system() == "Darwin"
+    is_arm = platform.machine() == "arm64"
+    available = (
+        is_macos
+        and is_arm
+        and module_available("mlx_whisper")
+    )
+    if not available and warn_on_missing and is_macos and is_arm:
+        logger.warning(
+            "=" * 50
+            + "\nMLX Whisper not found but you are on Apple Silicon. "
+            "Consider installing mlx-whisper for better performance: "
+            "`pip install mlx-whisper`\n"
+            + "=" * 50
+        )
+    return available
+
+
+def faster_backend_available(warn_on_missing=False):
+    available = module_available("faster_whisper")
+    if not available and warn_on_missing and platform.system() != "Darwin":
+        logger.warning(
+            "=" * 50
+            + "\nFaster-Whisper not found. Consider installing faster-whisper "
+            "for better performance: `pip install faster-whisper`\n"
+            + "=" * 50
+        )
+    return available
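For reference, the fallback order these helpers enable for `--backend auto` can be condensed into a standalone sketch (this is an approximation mirroring the documented MLX → Faster-Whisper → vanilla Whisper order, not the module above):

```python
import importlib.util
import platform

def module_available(name: str) -> bool:
    # find_spec() checks importability without actually importing the module.
    return importlib.util.find_spec(name) is not None

def pick_backend_auto() -> str:
    # Documented `auto` order: MLX on Apple Silicon, then Faster-Whisper,
    # then the built-in Whisper implementation.
    on_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    if on_apple_silicon and module_available("mlx_whisper"):
        return "mlx-whisper"
    if module_available("faster_whisper"):
        return "faster-whisper"
    return "whisper"

print(pick_backend_auto())
```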
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 31df810..5811f7c 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -3,6 +3,7 @@ from whisperlivekit.simul_whisper import SimulStreamingASR
 from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
 from argparse import Namespace
 import sys
+import logging
 
 def update_with_kwargs(_dict, kwargs):
     _dict.update({
@@ -10,6 +11,9 @@ def update_with_kwargs(_dict, kwargs):
     })
     return _dict
 
+
+logger = logging.getLogger(__name__)
+
 class TranscriptionEngine:
     _instance = None
     _initialized = False
@@ -41,16 +45,18 @@ class TranscriptionEngine:
             "pcm_input": False,
             "disable_punctuation_split" : False,
             "diarization_backend": "sortformer",
+            "backend_policy": "simulstreaming",
+            "backend": "auto",
         }
         global_params = update_with_kwargs(global_params, kwargs)
 
         transcription_common_params = {
-            "backend": "simulstreaming",
             "warmup_file": None,
             "min_chunk_size": 0.5,
             "model_size": "tiny",
             "model_cache_dir": None,
             "model_dir": None,
+            "model_path": None,
             "lan": "auto",
             "direct_english_translation": False,
         }
@@ -78,8 +84,9 @@ class TranscriptionEngine:
             use_onnx = kwargs.get('vac_onnx', False)
             self.vac_model = load_silero_vad(onnx=use_onnx)
 
+        backend_policy = self.args.backend_policy
         if self.args.transcription:
-            if self.args.backend == "simulstreaming":
+            if backend_policy == "simulstreaming":
                 simulstreaming_params = {
                     "disable_fast_encoder": False,
                     "custom_alignment_heads": None,
@@ -93,14 +100,19 @@ class TranscriptionEngine:
                     "init_prompt": None,
                     "static_init_prompt": None,
                     "max_context_tokens": None,
-                    "model_path": './base.pt',
                     "preload_model_count": 1,
                 }
                 simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
                 self.tokenizer = None
 
                 self.asr = SimulStreamingASR(
-                    **transcription_common_params, **simulstreaming_params
+                    **transcription_common_params,
+                    **simulstreaming_params,
+                    backend=self.args.backend,
+                )
+                logger.info(
+                    "Using SimulStreaming policy with %s backend",
+                    getattr(self.asr, "encoder_backend", "whisper"),
                 )
             else:
@@ -112,7 +124,13 @@ class TranscriptionEngine:
                 whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
 
                 self.asr = backend_factory(
-                    **transcription_common_params, **whisperstreaming_params
+                    backend=self.args.backend,
+                    **transcription_common_params,
+                    **whisperstreaming_params,
+                )
+                logger.info(
+                    "Using LocalAgreement policy with %s backend",
+                    getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
                 )
 
         if self.args.diarization:
@@ -133,7 +151,7 @@ class TranscriptionEngine:
             self.translation_model = None
 
         if self.args.target_language:
-            if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
+            if self.args.lan == 'auto' and backend_policy != "simulstreaming":
                 raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
             else:
                 try:
@@ -150,7 +168,7 @@ class TranscriptionEngine:
 
 def online_factory(args, asr):
-    if args.backend == "simulstreaming":
+    if args.backend_policy == "simulstreaming":
        from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
        online = SimulStreamingOnlineProcessor(asr)
    else:
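With the settings split as above, an embedding application can pick the streaming policy and the Whisper implementation independently. A hypothetical usage sketch (the keyword names are taken from the defaults dicts above; the exact constructor surface may differ):

```python
from whisperlivekit.core import TranscriptionEngine

# Hypothetical: LocalAgreement policy running on the Faster-Whisper backend.
# TranscriptionEngine forwards keyword arguments into its parameter dicts.
engine = TranscriptionEngine(
    backend_policy="localagreement",  # streaming strategy
    backend="faster-whisper",         # Whisper implementation
    model_size="tiny",
    lan="en",
)
```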
diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py
index 4fce26d..360df0e 100644
--- a/whisperlivekit/local_agreement/backends.py
+++ b/whisperlivekit/local_agreement/backends.py
@@ -6,6 +6,8 @@ import math
 from typing import List
 import numpy as np
 from whisperlivekit.timed_objects import ASRToken
+from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
+from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
 logger = logging.getLogger(__name__)
 class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
@@ -37,40 +39,60 @@ class ASRBase:
         raise NotImplementedError("must be implemented in the child class")
 
-class WhisperTimestampedASR(ASRBase):
-    """Uses whisper_timestamped as the backend."""
+class WhisperASR(ASRBase):
+    """Uses WhisperLiveKit's built-in Whisper implementation."""
     sep = " "
 
     def load_model(self, model_size=None, cache_dir=None, model_dir=None):
-        import whisper
-        import whisper_timestamped
-        from whisper_timestamped import transcribe_timestamped
+        from whisperlivekit.whisper import load_model
-        self.transcribe_timestamped = transcribe_timestamped
         if model_dir is not None:
-            logger.debug("ignoring model_dir, not implemented")
-        return whisper.load_model(model_size, download_root=cache_dir)
+            resolved_path = resolve_model_path(model_dir)
+            if resolved_path.is_dir():
+                pytorch_path, _, _ = model_path_and_type(resolved_path)
+                if pytorch_path is None:
+                    raise FileNotFoundError(
+                        f"No supported PyTorch checkpoint found under {resolved_path}"
+                    )
+                resolved_path = pytorch_path
+            logger.debug(f"Loading Whisper model from custom path {resolved_path}")
+            return load_model(str(resolved_path))
+
+        if model_size is None:
+            raise ValueError("Either model_size or model_dir must be set for WhisperASR")
+
+        return load_model(model_size, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
-        result = self.transcribe_timestamped(
+        options = dict(self.transcribe_kargs)
+        options.pop("vad", None)
+        options.pop("vad_filter", None)
+        language = self.original_language if self.original_language else None
+
+        result = whisper_transcribe(
             self.model,
             audio,
-            language=self.original_language,
+            language=language,
             initial_prompt=init_prompt,
-            verbose=None,
             condition_on_previous_text=True,
-            **self.transcribe_kargs,
+            word_timestamps=True,
+            **options,
         )
         return result
 
     def ts_words(self, r) -> List[ASRToken]:
         """
-        Converts the whisper_timestamped result to a list of ASRToken objects.
+        Converts the Whisper result to a list of ASRToken objects.
         """
         tokens = []
         for segment in r["segments"]:
             for word in segment["words"]:
-                token = ASRToken(word["start"], word["end"], word["text"])
+                token = ASRToken(
+                    word["start"],
+                    word["end"],
+                    word["word"],
+                    probability=word.get("probability"),
+                )
                 tokens.append(token)
         return tokens
@@ -78,7 +100,7 @@ class WhisperTimestampedASR(ASRBase):
         return [segment["end"] for segment in res["segments"]]
 
     def use_vad(self):
-        self.transcribe_kargs["vad"] = True
+        logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
 
 class FasterWhisperASR(ASRBase):
     """Uses faster-whisper as the backend."""
@@ -88,9 +110,10 @@ class FasterWhisperASR(ASRBase):
         from faster_whisper import WhisperModel
 
         if model_dir is not None:
-            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
+            resolved_path = resolve_model_path(model_dir)
+            logger.debug(f"Loading faster-whisper model from {resolved_path}. "
                          f"model_size and cache_dir parameters are not used.")
-            model_size_or_path = model_dir
+            model_size_or_path = str(resolved_path)
         elif model_size is not None:
             model_size_or_path = model_size
         else:
@@ -146,8 +169,9 @@ class MLXWhisper(ASRBase):
         import mlx.core as mx
 
         if model_dir is not None:
-            logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
-            model_size_or_path = model_dir
+            resolved_path = resolve_model_path(model_dir)
+            logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
+            model_size_or_path = str(resolved_path)
         elif model_size is not None:
             model_size_or_path = self.translate_model_name(model_size)
             logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
@@ -272,4 +296,4 @@ class OpenaiApiASR(ASRBase):
         return transcript
 
     def use_vad(self):
-        self.use_vad_opt = True
\ No newline at end of file
+        self.use_vad_opt = True
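The move from whisper_timestamped to the built-in transcribe changes the per-word key from `text` to `word` and exposes probabilities, which is exactly what the `ts_words` rewrite above absorbs. A toy illustration of that conversion, using a hand-built result in the shape Whisper returns with `word_timestamps=True` (the sample values are made up):

```python
# Hand-built result in the shape returned by Whisper with word_timestamps=True.
result = {
    "segments": [
        {"words": [
            {"start": 0.0, "end": 0.4, "word": " Hello", "probability": 0.98},
            {"start": 0.4, "end": 0.9, "word": " world", "probability": 0.95},
        ]}
    ]
}

# Same traversal as ts_words, minus the ASRToken wrapper.
tokens = [
    (w["start"], w["end"], w["word"], w.get("probability"))
    for segment in result["segments"]
    for w in segment["words"]
]
print(tokens)  # [(0.0, 0.4, ' Hello', 0.98), (0.4, 0.9, ' world', 0.95)]
```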
" f"model_size and cache_dir parameters are not used.") - model_size_or_path = model_dir + model_size_or_path = str(resolved_path) elif model_size is not None: model_size_or_path = model_size else: @@ -146,8 +169,9 @@ class MLXWhisper(ASRBase): import mlx.core as mx if model_dir is not None: - logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.") - model_size_or_path = model_dir + resolved_path = resolve_model_path(model_dir) + logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.") + model_size_or_path = str(resolved_path) elif model_size is not None: model_size_or_path = self.translate_model_name(model_size) logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.") @@ -272,4 +296,4 @@ class OpenaiApiASR(ASRBase): return transcript def use_vad(self): - self.use_vad_opt = True \ No newline at end of file + self.use_vad_opt = True diff --git a/whisperlivekit/local_agreement/whisper_online.py b/whisperlivekit/local_agreement/whisper_online.py index aac85b5..87b43ff 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -5,13 +5,18 @@ import librosa from functools import lru_cache import time import logging -from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR +import platform +from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR from whisperlivekit.warmup import warmup_asr +from whisperlivekit.model_paths import resolve_model_path, model_path_and_type +from whisperlivekit.backend_support import ( + mlx_backend_available, + faster_backend_available, +) logger = logging.getLogger(__name__) - WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split( "," ) @@ -70,6 +75,7 @@ def backend_factory( model_size, model_cache_dir, model_dir, + model_path, direct_english_translation, buffer_trimming, buffer_trimming_sec, @@ -77,27 +83,56 @@ def backend_factory( warmup_file=None, min_chunk_size=None, ): - backend = backend - if backend == "openai-api": - logger.debug("Using OpenAI API.") - asr = OpenaiApiASR(lan=lan) - else: - if backend == "faster-whisper": - asr_cls = FasterWhisperASR - elif backend == "mlx-whisper": - asr_cls = MLXWhisper - else: - asr_cls = WhisperTimestampedASR + backend_choice = backend + custom_reference = model_path or model_dir + resolved_root = None + pytorch_checkpoint = None + has_mlx_weights = False + has_fw_weights = False - # Only for FasterWhisperASR and WhisperTimestampedASR + if custom_reference: + resolved_root = resolve_model_path(custom_reference) + if resolved_root.is_dir(): + pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root) + else: + pytorch_checkpoint = resolved_root + + if backend_choice == "openai-api": + logger.debug("Using OpenAI API.") + asr = OpenaiApiASR(lan=lan) + else: + backend_choice = _normalize_backend_choice( + backend_choice, + resolved_root, + has_mlx_weights, + has_fw_weights, + ) + + if backend_choice == "faster-whisper": + asr_cls = FasterWhisperASR + if resolved_root is not None and not resolved_root.is_dir(): + raise ValueError("Faster-Whisper backend expects a directory with 
diff --git a/whisperlivekit/model_paths.py b/whisperlivekit/model_paths.py
new file mode 100644
index 0000000..2e3eb0a
--- /dev/null
+++ b/whisperlivekit/model_paths.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+
+def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
+    """
+    Inspect the provided path and determine which model formats are available.
+
+    Returns:
+        pytorch_path: Path to a PyTorch checkpoint (if present).
+        compatible_whisper_mlx: True if MLX weights exist in this folder.
+        compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
+    """
+    path = Path(model_path)
+
+    compatible_whisper_mlx = False
+    compatible_faster_whisper = False
+    pytorch_path: Optional[Path] = None
+
+    if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
+        pytorch_path = path
+        return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
+
+    if path.is_dir():
+        for file in path.iterdir():
+            if not file.is_file():
+                continue
+
+            filename = file.name.lower()
+            suffix = file.suffix.lower()
+
+            if filename in {"weights.npz", "weights.safetensors"}:
+                compatible_whisper_mlx = True
+            elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
+                compatible_faster_whisper = True
+            elif suffix in {".pt", ".safetensors"}:
+                pytorch_path = file
+            elif filename == "pytorch_model.bin":
+                pytorch_path = file
+
+        if pytorch_path is None:
+            fallback = path / "pytorch_model.bin"
+            if fallback.exists():
+                pytorch_path = fallback
+
+    return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
+
+
+def resolve_model_path(model_path: Union[str, Path]) -> Path:
+    """
+    Return a local path for the provided model reference.
+
+    If the path does not exist locally, it is treated as a Hugging Face repo id
+    and downloaded via snapshot_download.
+    """
+    path = Path(model_path).expanduser()
+    if path.exists():
+        return path
+
+    try:
+        from huggingface_hub import snapshot_download
+    except ImportError as exc:  # pragma: no cover - optional dependency guard
+        raise FileNotFoundError(
+            f"Model path '{model_path}' does not exist locally and huggingface_hub "
+            "is not installed to download it."
+        ) from exc
+
+    downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
+    return downloaded_path
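A quick usage sketch of the directory inspection (assuming this PR's module is installed; the folder layout is hypothetical):

```python
from pathlib import Path
from tempfile import TemporaryDirectory

from whisperlivekit.model_paths import model_path_and_type

# Hypothetical folder holding both a CTranslate2 export and a .pt checkpoint.
with TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "model.bin").touch()   # Faster-Whisper (CTranslate2) weights
    (root / "small.pt").touch()    # PyTorch checkpoint
    pt, has_mlx, has_fw = model_path_and_type(root)
    print(pt.name, has_mlx, has_fw)  # small.pt False True
```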
+ """ + path = Path(model_path) + + compatible_whisper_mlx = False + compatible_faster_whisper = False + pytorch_path: Optional[Path] = None + + if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]: + pytorch_path = path + return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper + + if path.is_dir(): + for file in path.iterdir(): + if not file.is_file(): + continue + + filename = file.name.lower() + suffix = file.suffix.lower() + + if filename in {"weights.npz", "weights.safetensors"}: + compatible_whisper_mlx = True + elif filename in {"model.bin", "encoder.bin", "decoder.bin"}: + compatible_faster_whisper = True + elif suffix in {".pt", ".safetensors"}: + pytorch_path = file + elif filename == "pytorch_model.bin": + pytorch_path = file + + if pytorch_path is None: + fallback = path / "pytorch_model.bin" + if fallback.exists(): + pytorch_path = fallback + + return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper + + +def resolve_model_path(model_path: Union[str, Path]) -> Path: + """ + Return a local path for the provided model reference. + + If the path does not exist locally, it is treated as a Hugging Face repo id + and downloaded via snapshot_download. + """ + path = Path(model_path).expanduser() + if path.exists(): + return path + + try: + from huggingface_hub import snapshot_download + except ImportError as exc: # pragma: no cover - optional dependency guard + raise FileNotFoundError( + f"Model path '{model_path}' does not exist locally and huggingface_hub " + "is not installed to download it." + ) from exc + + downloaded_path = Path(snapshot_download(repo_id=str(model_path))) + return downloaded_path diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index cbaea22..624c59c 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -129,11 +129,18 @@ def parse_args(): ) parser.add_argument( - "--backend", + "--backend-policy", type=str, default="simulstreaming", - choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"], - help="Load only this backend for Whisper processing.", + choices=["1", "2", "simulstreaming", "localagreement"], + help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.", + ) + parser.add_argument( + "--backend", + type=str, + default="auto", + choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"], + help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). 
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index 8342c32..3416898 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -1,7 +1,7 @@
 import sys
 import numpy as np
 import logging
-from typing import List, Tuple, Optional, Union
+from typing import List, Tuple, Optional
 import platform
 from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
 from whisperlivekit.warmup import load_file
@@ -10,6 +10,11 @@ from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
 import os
 import gc
 from pathlib import Path
+from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
+from whisperlivekit.backend_support import (
+    mlx_backend_available,
+    faster_backend_available,
+)
 import torch
 
 from whisperlivekit.simul_whisper.config import AlignAttConfig
@@ -18,65 +23,16 @@ from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
 
 logger = logging.getLogger(__name__)
 
-try:
-    from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
-    HAS_MLX_WHISPER = True
-except ImportError:
-    if platform.system() == "Darwin" and platform.machine() == "arm64":
-        print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: `pip install mlx-whisper`\n{"="*50}""")
-    HAS_MLX_WHISPER = False
+HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
 if HAS_MLX_WHISPER:
-    HAS_FASTER_WHISPER = False
+    from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
 else:
-    try:
-        from faster_whisper import WhisperModel
-        HAS_FASTER_WHISPER = True
-    except ImportError:
-        if platform.system() != "Darwin":
-            print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""")
-        HAS_FASTER_WHISPER = False
-
-def model_path_and_type(model_path: Union[str, Path]):
-    path = Path(model_path)
-
-    compatible_whisper_mlx = False
-    compatible_faster_whisper = False
-    pytorch_path = None
-    if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
-        pytorch_path = path
-    elif path.is_dir():
-        for file in path.iterdir():
-            if file.is_file():
-                if file.name in ['weights.npz', "weights.safetensors"]:
-                    compatible_whisper_mlx = True
-                elif file.suffix.lower() == '.bin':
-                    compatible_faster_whisper = True
-                elif file.suffix.lower() == '.pt':
-                    pytorch_path = file
-                elif file.suffix.lower() == '.safetensors':
-                    pytorch_path = file
-    if pytorch_path is None:
-        if (path / Path("pytorch_model.bin")).exists():
-            pytorch_path = path / Path("pytorch_model.bin")
-    return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
-
-
-def resolve_model_path(model_path: str) -> Path:
-    path = Path(model_path)
-    if path.exists():
-        return path
-
-    try:
-        from huggingface_hub import snapshot_download
-    except ImportError as exc:
-        raise FileNotFoundError(
-            f"Model path '{model_path}' does not exist locally and huggingface_hub "
-            "is not installed to download it."
-        ) from exc
-
-    downloaded_path = Path(snapshot_download(repo_id=model_path))
-    return downloaded_path
-
+    mlx_model_mapping = {}
+HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
+if HAS_FASTER_WHISPER:
+    from faster_whisper import WhisperModel
+else:
+    WhisperModel = None
 
 class SimulStreamingOnlineProcessor:
     SAMPLING_RATE = 16000
@@ -194,13 +150,22 @@ class SimulStreamingASR():
         self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
 
         self.fast_encoder = False
+        self._resolved_model_path = None
+        self.encoder_backend = "whisper"
+        preferred_backend = getattr(self, "backend", "auto")
         self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
         if self.model_path:
             resolved_model_path = resolve_model_path(self.model_path)
+            self._resolved_model_path = resolved_model_path
             self.model_path = str(resolved_model_path)
             self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
-            self.model_name = self.pytorch_path.stem
-            is_multilingual = not self.model_path.endswith(".en")
+            if self.pytorch_path:
+                self.model_name = self.pytorch_path.stem
+            else:
+                self.model_name = Path(self.model_path).stem
+                raise FileNotFoundError(
+                    f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
+                )
         elif self.model_size is not None:
             model_mapping = {
                 'tiny': './tiny.pt',
@@ -217,7 +182,19 @@ class SimulStreamingASR():
                 'large': './large-v3.pt'
             }
             self.model_name = self.model_size
-            is_multilingual = not self.model_name.endswith(".en")
+        else:
+            raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
+
+        is_multilingual = not self.model_name.endswith(".en")
+
+        self.encoder_backend = self._resolve_encoder_backend(
+            preferred_backend,
+            compatible_whisper_mlx,
+            compatible_faster_whisper,
+        )
+        self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
+        if self.encoder_backend == "whisper":
+            self.disable_fast_encoder = True
 
         self.cfg = AlignAttConfig(
             tokenizer_is_multilingual= is_multilingual,
@@ -246,32 +223,72 @@ class SimulStreamingASR():
 
         self.mlx_encoder, self.fw_encoder = None, None
-        if not self.disable_fast_encoder:
-            if HAS_MLX_WHISPER:
-                print('Simulstreaming will use MLX whisper to increase encoding speed.')
-                if self.model_path and compatible_whisper_mlx:
-                    mlx_model = self.model_path
-                else:
-                    mlx_model = mlx_model_mapping.get(self.model_name)
-                if mlx_model:
-                    self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
-                    self.fast_encoder = True
-            elif HAS_FASTER_WHISPER and compatible_faster_whisper:
-                print('Simulstreaming will use Faster Whisper for the encoder.')
-                if self.model_path and compatible_faster_whisper:
-                    fw_model = self.model_path
-                else:
-                    fw_model = self.model_name
-                self.fw_encoder = WhisperModel(
-                    fw_model,
-                    device='auto',
-                    compute_type='auto',
+        if self.encoder_backend == "mlx-whisper":
+            print('Simulstreaming will use MLX whisper to increase encoding speed.')
+            if self._resolved_model_path is not None:
+                mlx_model = str(self._resolved_model_path)
+            else:
+                mlx_model = mlx_model_mapping.get(self.model_name)
+            if not mlx_model:
+                raise FileNotFoundError(
+                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
                 )
-                self.fast_encoder = True
+            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
+        elif self.encoder_backend == "faster-whisper":
+            print('Simulstreaming will use Faster Whisper for the encoder.')
+            if self._resolved_model_path is not None:
+                fw_model = str(self._resolved_model_path)
+            else:
+                fw_model = self.model_name
+            self.fw_encoder = WhisperModel(
+                fw_model,
+                device='auto',
+                compute_type='auto',
+            )
 
         self.models = [self.load_model() for i in range(self.preload_model_count)]
 
+    def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
+        choice = preferred_backend or "auto"
+        if self.disable_fast_encoder:
+            return "whisper"
+        if choice == "whisper":
+            return "whisper"
+        if choice == "mlx-whisper":
+            if not self._can_use_mlx(compatible_whisper_mlx):
+                raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
+            return "mlx-whisper"
+        if choice == "faster-whisper":
+            if not self._can_use_faster(compatible_faster_whisper):
+                raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
+            return "faster-whisper"
+        if choice == "openai-api":
+            raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
+        # auto mode
+        if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
+            return "mlx-whisper"
+        if self._can_use_faster(compatible_faster_whisper):
+            return "faster-whisper"
+        return "whisper"
+
+    def _has_custom_model_path(self):
+        return self._resolved_model_path is not None
+
+    def _can_use_mlx(self, compatible_whisper_mlx):
+        if not HAS_MLX_WHISPER:
+            return False
+        if self._has_custom_model_path():
+            return compatible_whisper_mlx
+        return self.model_name in mlx_model_mapping
+
+    def _can_use_faster(self, compatible_faster_whisper):
+        if not HAS_FASTER_WHISPER:
+            return False
+        if self._has_custom_model_path():
+            return compatible_faster_whisper
+        return True
+
     def load_model(self):
         whisper_model = load_model(
             name=self.pytorch_path if self.pytorch_path else self.model_name,
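Condensed, the encoder resolution above follows this precedence (a standalone sketch; the install and weight-compatibility checks are folded into the boolean arguments):

```python
def resolve_encoder(choice, disable_fast, can_mlx, can_fw, on_macos):
    # An explicit request is honored or fails loudly; `auto` falls back.
    if disable_fast or choice == "whisper":
        return "whisper"
    if choice == "mlx-whisper":
        if not can_mlx:
            raise RuntimeError("mlx-whisper requested but unavailable")
        return "mlx-whisper"
    if choice == "faster-whisper":
        if not can_fw:
            raise RuntimeError("faster-whisper requested but unavailable")
        return "faster-whisper"
    # auto: MLX on macOS first, then Faster-Whisper, then plain Whisper.
    if on_macos and can_mlx:
        return "mlx-whisper"
    if can_fw:
        return "faster-whisper"
    return "whisper"

print(resolve_encoder("auto", False, can_mlx=False, can_fw=True, on_macos=False))
# -> faster-whisper
```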
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 697abec..4ba900d 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -17,6 +17,10 @@ from .eow_detection import fire_at_boundary, load_cif
 import os
 from time import time
 from .token_buffer import TokenBuffer
+from whisperlivekit.backend_support import (
+    mlx_backend_available,
+    faster_backend_available,
+)
 import numpy as np
 from ..timed_objects import PUNCTUATION_MARKS
 
@@ -26,21 +30,18 @@
 DEC_PAD = 50257
 
 logger = logging.getLogger(__name__)
 
-try:
+HAS_MLX_WHISPER = False
+HAS_FASTER_WHISPER = False
+
+if mlx_backend_available():
     from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
     from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
     HAS_MLX_WHISPER = True
-except ImportError:
-    HAS_MLX_WHISPER = False
-if HAS_MLX_WHISPER:
-    HAS_FASTER_WHISPER = False
-else:
-    try:
-        from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
-        from faster_whisper.feature_extractor import FeatureExtractor
-        HAS_FASTER_WHISPER = True
-    except ImportError:
-        HAS_FASTER_WHISPER = False
+
+if faster_backend_available():
+    from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
+    from faster_whisper.feature_extractor import FeatureExtractor
+    HAS_FASTER_WHISPER = True
 
 class PaddedAlignAttWhisper:
     def __init__(