From 8dcebd9329a6d4a033e2d7f8624dd2e1eae1ecc7 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Thu, 19 Dec 2024 11:10:02 +0100 Subject: [PATCH] add translate_model_name function --- whisper_online.py | 58 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/whisper_online.py b/whisper_online.py index 53c8417..ce61a52 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -160,27 +160,71 @@ class MLXWhisper(ASRBase): """ Uses MPX Whisper library as the backend, optimized for Apple Silicon. Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc - Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx + Significantly faster than faster-whisper (without CUDA) on Apple M1. """ sep = " " - def load_model(self, modelsize=None, model_dir=None): + def load_model(self, modelsize=None, cache_dir=None, model_dir=None): + """ + Loads the MLX-compatible Whisper model. + + Args: + modelsize (str, optional): The size or name of the Whisper model to load. + If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method. + Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo". + cache_dir (str, optional): Path to the directory for caching models. + **Note**: This is not supported by MLX Whisper and will be ignored. + model_dir (str, optional): Direct path to a custom model directory. + If specified, it overrides the `modelsize` parameter. + """ from mlx_whisper import transcribe if model_dir is not None: logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.") model_size_or_path = model_dir elif modelsize is not None: - logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.") - model_size_or_path = modelsize - elif modelsize == None: - logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.") - model_size_or_path = "mlx-community/whisper-large-v3-mlx" + model_size_or_path = self.translate_model_name(modelsize) + logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.") self.model_size_or_path = model_size_or_path return transcribe + def translate_model_name(self, model_name): + """ + Translates a given model name to its corresponding MLX-compatible model path. + + Args: + model_name (str): The name of the model to translate. + + Returns: + str: The MLX-compatible model path. + """ + # Dictionary mapping model names to MLX-compatible paths + model_mapping = { + "tiny.en": "mlx-community/whisper-tiny.en-mlx", + "tiny": "mlx-community/whisper-tiny-mlx", + "base.en": "mlx-community/whisper-base.en-mlx", + "base": "mlx-community/whisper-base-mlx", + "small.en": "mlx-community/whisper-small.en-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium.en": "mlx-community/whisper-medium.en-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v1": "mlx-community/whisper-large-v1-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "large": "mlx-community/whisper-large-mlx" + } + + # Retrieve the corresponding MLX model path + mlx_model_path = model_mapping.get(model_name) + + if mlx_model_path: + return mlx_model_path + else: + raise ValueError(f"Model name '{model_name}' is not recognized or not supported.") + def transcribe(self, audio, init_prompt=""): segments = self.model( audio,