mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
Refactor backend handling
This commit is contained in:
@@ -144,7 +144,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
|
||||
41
whisperlivekit/backend_support.py
Normal file
41
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def module_available(module_name):
|
||||
"""Return True if the given module can be imported."""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def mlx_backend_available(warn_on_missing = False):
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_arm = platform.machine() == "arm64"
|
||||
available = (
|
||||
is_macos
|
||||
and is_arm
|
||||
and module_available("mlx_whisper")
|
||||
)
|
||||
if not available and warn_on_missing and is_macos and is_arm:
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||
"Consider installing mlx-whisper for better performance: "
|
||||
"`pip install mlx-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||
"for better performance: `pip install faster-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
@@ -3,6 +3,7 @@ from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
import logging
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
@@ -10,6 +11,9 @@ def update_with_kwargs(_dict, kwargs):
|
||||
})
|
||||
return _dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -41,16 +45,18 @@ class TranscriptionEngine:
|
||||
"pcm_input": False,
|
||||
"disable_punctuation_split" : False,
|
||||
"diarization_backend": "sortformer",
|
||||
"backend_policy": "simulstreaming",
|
||||
"backend": "auto",
|
||||
}
|
||||
global_params = update_with_kwargs(global_params, kwargs)
|
||||
|
||||
transcription_common_params = {
|
||||
"backend": "simulstreaming",
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.5,
|
||||
"model_size": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
"lan": "auto",
|
||||
"direct_english_translation": False,
|
||||
}
|
||||
@@ -78,8 +84,9 @@ class TranscriptionEngine:
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
|
||||
backend_policy = self.args.backend_policy
|
||||
if self.args.transcription:
|
||||
if self.args.backend == "simulstreaming":
|
||||
if backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
@@ -93,14 +100,19 @@ class TranscriptionEngine:
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
self.tokenizer = None
|
||||
self.asr = SimulStreamingASR(
|
||||
**transcription_common_params, **simulstreaming_params
|
||||
**transcription_common_params,
|
||||
**simulstreaming_params,
|
||||
backend=self.args.backend,
|
||||
)
|
||||
logger.info(
|
||||
"Using SimulStreaming policy with %s backend",
|
||||
getattr(self.asr, "encoder_backend", "whisper"),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -112,7 +124,13 @@ class TranscriptionEngine:
|
||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||
|
||||
self.asr = backend_factory(
|
||||
**transcription_common_params, **whisperstreaming_params
|
||||
backend=self.args.backend,
|
||||
**transcription_common_params,
|
||||
**whisperstreaming_params,
|
||||
)
|
||||
logger.info(
|
||||
"Using LocalAgreement policy with %s backend",
|
||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||
)
|
||||
|
||||
if self.args.diarization:
|
||||
@@ -133,7 +151,7 @@ class TranscriptionEngine:
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
try:
|
||||
@@ -150,7 +168,7 @@ class TranscriptionEngine:
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if args.backend == "simulstreaming":
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,8 @@ import math
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
@@ -37,40 +39,60 @@ class ASRBase:
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperTimestampedASR(ASRBase):
|
||||
"""Uses whisper_timestamped as the backend."""
|
||||
class WhisperASR(ASRBase):
|
||||
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
from whisperlivekit.whisper import load_model as load_model
|
||||
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
logger.debug("ignoring model_dir, not implemented")
|
||||
return whisper.load_model(model_size, download_root=cache_dir)
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||
if pytorch_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
resolved_path = pytorch_path
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_model(str(resolved_path))
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = self.transcribe_timestamped(
|
||||
options = dict(self.transcribe_kargs)
|
||||
options.pop("vad", None)
|
||||
options.pop("vad_filter", None)
|
||||
language = self.original_language if self.original_language else None
|
||||
|
||||
result = whisper_transcribe(
|
||||
self.model,
|
||||
audio,
|
||||
language=self.original_language,
|
||||
language=language,
|
||||
initial_prompt=init_prompt,
|
||||
verbose=None,
|
||||
condition_on_previous_text=True,
|
||||
**self.transcribe_kargs,
|
||||
word_timestamps=True,
|
||||
**options,
|
||||
)
|
||||
return result
|
||||
|
||||
def ts_words(self, r) -> List[ASRToken]:
|
||||
"""
|
||||
Converts the whisper_timestamped result to a list of ASRToken objects.
|
||||
Converts the Whisper result to a list of ASRToken objects.
|
||||
"""
|
||||
tokens = []
|
||||
for segment in r["segments"]:
|
||||
for word in segment["words"]:
|
||||
token = ASRToken(word["start"], word["end"], word["text"])
|
||||
token = ASRToken(
|
||||
word["start"],
|
||||
word["end"],
|
||||
word["word"],
|
||||
probability=word.get("probability"),
|
||||
)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -78,7 +100,7 @@ class WhisperTimestampedASR(ASRBase):
|
||||
return [segment["end"] for segment in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad"] = True
|
||||
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
@@ -88,9 +110,10 @@ class FasterWhisperASR(ASRBase):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = model_dir
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
@@ -146,8 +169,9 @@ class MLXWhisper(ASRBase):
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
|
||||
model_size_or_path = model_dir
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
@@ -272,4 +296,4 @@ class OpenaiApiASR(ASRBase):
|
||||
return transcript
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
self.use_vad_opt = True
|
||||
|
||||
@@ -5,13 +5,18 @@ import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||
import platform
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
@@ -70,6 +75,7 @@ def backend_factory(
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
@@ -77,27 +83,56 @@ def backend_factory(
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend = backend
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
elif backend == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
pytorch_checkpoint = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
||||
else:
|
||||
pytorch_checkpoint = resolved_root
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
backend_choice = _normalize_backend_choice(
|
||||
backend_choice,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
)
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
elif backend_choice == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
||||
if custom_reference and model_override is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan}...")
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_dir,
|
||||
model_dir=model_override,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
@@ -119,4 +154,46 @@ def backend_factory(
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
return asr
|
||||
asr.backend_choice = backend_choice
|
||||
return asr
|
||||
|
||||
|
||||
def _normalize_backend_choice(
|
||||
preferred_backend,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
):
|
||||
backend_choice = preferred_backend
|
||||
|
||||
if backend_choice == "auto":
|
||||
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
|
||||
return "mlx-whisper"
|
||||
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
if backend_choice == "mlx-whisper":
|
||||
if not mlx_backend_available():
|
||||
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
|
||||
if resolved_root is not None and not has_mlx_weights:
|
||||
raise FileNotFoundError(
|
||||
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
|
||||
)
|
||||
if platform.system() != "Darwin":
|
||||
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
if not faster_backend_available():
|
||||
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
|
||||
if resolved_root is not None and not has_fw_weights:
|
||||
raise FileNotFoundError(
|
||||
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
|
||||
)
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "whisper":
|
||||
return backend_choice
|
||||
|
||||
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")
|
||||
|
||||
69
whisperlivekit/model_paths.py
Normal file
69
whisperlivekit/model_paths.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
||||
"""
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path: Optional[Path] = None
|
||||
|
||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
||||
pytorch_path = path
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename in {"weights.npz", "weights.safetensors"}:
|
||||
compatible_whisper_mlx = True
|
||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
||||
compatible_faster_whisper = True
|
||||
elif suffix in {".pt", ".safetensors"}:
|
||||
pytorch_path = file
|
||||
elif filename == "pytorch_model.bin":
|
||||
pytorch_path = file
|
||||
|
||||
if pytorch_path is None:
|
||||
fallback = path / "pytorch_model.bin"
|
||||
if fallback.exists():
|
||||
pytorch_path = fallback
|
||||
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Return a local path for the provided model reference.
|
||||
|
||||
If the path does not exist locally, it is treated as a Hugging Face repo id
|
||||
and downloaded via snapshot_download.
|
||||
"""
|
||||
path = Path(model_path).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
|
||||
return downloaded_path
|
||||
@@ -129,11 +129,18 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
"--backend-policy",
|
||||
type=str,
|
||||
default="simulstreaming",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
@@ -316,5 +323,10 @@ def parse_args():
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
if args.backend_policy == "1":
|
||||
args.backend_policy = "simulstreaming"
|
||||
elif args.backend_policy == "2":
|
||||
args.backend_policy = "localagreement"
|
||||
|
||||
return args
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional, Union
|
||||
from typing import List, Tuple, Optional
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
@@ -10,6 +10,11 @@ from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
@@ -18,65 +23,16 @@ from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: `pip install mlx-whisper`\n{"="*50}""")
|
||||
HAS_MLX_WHISPER = False
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
else:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() != "Darwin":
|
||||
print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""")
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]):
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path = None
|
||||
if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
|
||||
pytorch_path = path
|
||||
elif path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if file.is_file():
|
||||
if file.name in ['weights.npz', "weights.safetensors"]:
|
||||
compatible_whisper_mlx = True
|
||||
elif file.suffix.lower() == '.bin':
|
||||
compatible_faster_whisper = True
|
||||
elif file.suffix.lower() == '.pt':
|
||||
pytorch_path = file
|
||||
elif file.suffix.lower() == '.safetensors':
|
||||
pytorch_path = file
|
||||
if pytorch_path is None:
|
||||
if (path / Path("pytorch_model.bin")).exists():
|
||||
pytorch_path = path / Path("pytorch_model.bin")
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: str) -> Path:
|
||||
path = Path(model_path)
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc:
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=model_path))
|
||||
return downloaded_path
|
||||
|
||||
mlx_model_mapping = {}
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
@@ -194,13 +150,22 @@ class SimulStreamingASR():
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
self.model_name = self.pytorch_path.stem
|
||||
is_multilingual = not self.model_path.endswith(".en")
|
||||
if self.pytorch_path:
|
||||
self.model_name = self.pytorch_path.stem
|
||||
else:
|
||||
self.model_name = Path(self.model_path).stem
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
@@ -217,7 +182,19 @@ class SimulStreamingASR():
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_name = self.model_size
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
|
||||
self.encoder_backend = self._resolve_encoder_backend(
|
||||
preferred_backend,
|
||||
compatible_whisper_mlx,
|
||||
compatible_faster_whisper,
|
||||
)
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
@@ -246,32 +223,72 @@ class SimulStreamingASR():
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
if self.model_path and compatible_whisper_mlx:
|
||||
mlx_model = self.model_path
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if mlx_model:
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self.model_path and compatible_faster_whisper:
|
||||
fw_model = self.model_path
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.fast_encoder = True
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
choice = preferred_backend or "auto"
|
||||
if self.disable_fast_encoder:
|
||||
return "whisper"
|
||||
if choice == "whisper":
|
||||
return "whisper"
|
||||
if choice == "mlx-whisper":
|
||||
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||
return "mlx-whisper"
|
||||
if choice == "faster-whisper":
|
||||
if not self._can_use_faster(compatible_faster_whisper):
|
||||
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||
return "faster-whisper"
|
||||
if choice == "openai-api":
|
||||
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||
# auto mode
|
||||
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||
return "mlx-whisper"
|
||||
if self._can_use_faster(compatible_faster_whisper):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
def _has_custom_model_path(self):
|
||||
return self._resolved_model_path is not None
|
||||
|
||||
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||
if not HAS_MLX_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_whisper_mlx
|
||||
return self.model_name in mlx_model_mapping
|
||||
|
||||
def _can_use_faster(self, compatible_faster_whisper):
|
||||
if not HAS_FASTER_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_faster_whisper
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
|
||||
@@ -17,6 +17,10 @@ from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
@@ -26,21 +30,18 @@ DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
HAS_MLX_WHISPER = False
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
else:
|
||||
try:
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
|
||||
class PaddedAlignAttWhisper:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user