mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Merge branch 'whisper-mlx'
This commit is contained in:
@@ -160,27 +160,71 @@ class MLXWhisper(ASRBase):
|
||||
"""
|
||||
Uses MPX Whisper library as the backend, optimized for Apple Silicon.
|
||||
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
|
||||
Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
|
||||
Significantly faster than faster-whisper (without CUDA) on Apple M1.
|
||||
"""
|
||||
|
||||
sep = " "
|
||||
|
||||
def load_model(self, modelsize=None, model_dir=None):
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
"""
|
||||
Loads the MLX-compatible Whisper model.
|
||||
|
||||
Args:
|
||||
modelsize (str, optional): The size or name of the Whisper model to load.
|
||||
If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
|
||||
Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
|
||||
cache_dir (str, optional): Path to the directory for caching models.
|
||||
**Note**: This is not supported by MLX Whisper and will be ignored.
|
||||
model_dir (str, optional): Direct path to a custom model directory.
|
||||
If specified, it overrides the `modelsize` parameter.
|
||||
"""
|
||||
from mlx_whisper import transcribe
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
|
||||
model_size_or_path = modelsize
|
||||
elif modelsize == None:
|
||||
logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
|
||||
model_size_or_path = "mlx-community/whisper-large-v3-mlx"
|
||||
model_size_or_path = self.translate_model_name(modelsize)
|
||||
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
return transcribe
|
||||
|
||||
def translate_model_name(self, model_name):
|
||||
"""
|
||||
Translates a given model name to its corresponding MLX-compatible model path.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to translate.
|
||||
|
||||
Returns:
|
||||
str: The MLX-compatible model path.
|
||||
"""
|
||||
# Dictionary mapping model names to MLX-compatible paths
|
||||
model_mapping = {
|
||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||
"base": "mlx-community/whisper-base-mlx",
|
||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||
"small": "mlx-community/whisper-small-mlx",
|
||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||
"medium": "mlx-community/whisper-medium-mlx",
|
||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||
"large": "mlx-community/whisper-large-mlx"
|
||||
}
|
||||
|
||||
# Retrieve the corresponding MLX model path
|
||||
mlx_model_path = model_mapping.get(model_name)
|
||||
|
||||
if mlx_model_path:
|
||||
return mlx_model_path
|
||||
else:
|
||||
raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
segments = self.model(
|
||||
audio,
|
||||
|
||||
Reference in New Issue
Block a user