add translate_model_name function

This commit is contained in:
Quentin Fuxa
2024-12-19 11:10:02 +01:00
parent 87cab7c280
commit 8dcebd9329

View File

@@ -160,27 +160,71 @@ class MLXWhisper(ASRBase):
"""
Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
Significantly faster than faster-whisper (without CUDA) on Apple M1.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
    """
    Selects the MLX Whisper model to use and returns the transcribe function.

    The chosen model path is stored in `self.model_size_or_path`; no model
    weights are loaded by this method itself.

    Args:
        modelsize (str, optional): The size or name of the Whisper model to load.
            If provided, it will be translated to an MLX-compatible model path
            using the `translate_model_name` method.
            Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
        cache_dir (str, optional): Path to the directory for caching models.
            **Note**: This is not supported by MLX Whisper and will be ignored.
        model_dir (str, optional): Direct path to a custom model directory.
            If specified, it overrides the `modelsize` parameter.

    Returns:
        The `transcribe` function imported from `mlx_whisper`.

    Raises:
        ValueError: Propagated from `translate_model_name` when `modelsize`
            is not a supported model name.
    """
    # Imported lazily so the module can be imported without mlx_whisper installed.
    from mlx_whisper import transcribe

    if model_dir is not None:
        logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
        model_size_or_path = model_dir
    elif modelsize is not None:
        model_size_or_path = self.translate_model_name(modelsize)
        logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
    else:
        # Neither argument given: fall back to the default model so that
        # model_size_or_path is always bound before it is stored below.
        logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
        model_size_or_path = "mlx-community/whisper-large-v3-mlx"

    self.model_size_or_path = model_size_or_path
    return transcribe
def translate_model_name(self, model_name):
    """
    Translates a given model name to its corresponding MLX-compatible model path.

    Args:
        model_name (str): The name of the model to translate,
            e.g. "large-v3-turbo".

    Returns:
        str: The MLX-compatible model path,
            e.g. "mlx-community/whisper-large-v3-turbo".

    Raises:
        ValueError: If `model_name` is not one of the supported names
            (this includes None and any non-string value).
    """
    # Dictionary mapping model names to MLX-compatible paths
    model_mapping = {
        "tiny.en": "mlx-community/whisper-tiny.en-mlx",
        "tiny": "mlx-community/whisper-tiny-mlx",
        "base.en": "mlx-community/whisper-base.en-mlx",
        "base": "mlx-community/whisper-base-mlx",
        "small.en": "mlx-community/whisper-small.en-mlx",
        "small": "mlx-community/whisper-small-mlx",
        "medium.en": "mlx-community/whisper-medium.en-mlx",
        "medium": "mlx-community/whisper-medium-mlx",
        "large-v1": "mlx-community/whisper-large-v1-mlx",
        "large-v2": "mlx-community/whisper-large-v2-mlx",
        "large-v3": "mlx-community/whisper-large-v3-mlx",
        "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
        "large": "mlx-community/whisper-large-mlx",
    }

    # Only strings can match a key. Checking the type first keeps an
    # unhashable argument (e.g. a list) from raising TypeError inside the
    # dict lookup, so every bad input surfaces as the documented ValueError.
    mlx_model_path = model_mapping.get(model_name) if isinstance(model_name, str) else None

    if mlx_model_path is None:
        raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
    return mlx_model_path
def transcribe(self, audio, init_prompt=""):
segments = self.model(
audio,