warning when transcribe_kargs are used with MLX Whisper

This commit is contained in:
Quentin Fuxa
2025-01-14 20:14:16 +01:00
parent 0ff6067f37
commit f884d1162d

View File

@@ -201,7 +201,8 @@ class MLXWhisper(ASRBase):
model_dir (str, optional): Direct path to a custom model directory.
If specified, it overrides the `modelsize` parameter.
"""
from mlx_whisper import transcribe
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx
if model_dir is not None:
logger.debug(
@@ -215,6 +216,12 @@ class MLXWhisper(ASRBase):
)
self.model_size_or_path = model_size_or_path
# In mlx_whisper.transcribe, dtype is defined as:
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
dtype = mx.float16
ModelHolder.get_model(model_size_or_path, dtype)
return transcribe
def translate_model_name(self, model_name):
@@ -255,6 +262,8 @@ class MLXWhisper(ASRBase):
)
def transcribe(self, audio, init_prompt=""):
if self.transcribe_kargs:
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
segments = self.model(
audio,
language=self.original_language,
@@ -262,7 +271,6 @@ class MLXWhisper(ASRBase):
word_timestamps=True,
condition_on_previous_text=True,
path_or_hf_repo=self.model_size_or_path,
**self.transcribe_kargs,
)
return segments.get("segments", [])
@@ -844,7 +852,7 @@ def add_shared_args(parser):
parser.add_argument(
"--model",
type=str,
default="large-v2",
default="tiny",
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
","
),
@@ -879,14 +887,14 @@ def add_shared_args(parser):
parser.add_argument(
"--backend",
type=str,
default="faster-whisper",
default="mlx-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=False,
default=True,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
@@ -895,7 +903,7 @@ def add_shared_args(parser):
parser.add_argument(
"--vad",
action="store_true",
default=False,
default=True,
help="Use VAD = voice activity detection, with the default parameters.",
)
parser.add_argument(
@@ -1006,8 +1014,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_path",
"--audio_path",
type=str,
default='samples_jfk.wav',
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
)
add_shared_args(parser)