mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
warning when transcribe_kargs are used with MLX Whisper
This commit is contained in:
@@ -201,7 +201,8 @@ class MLXWhisper(ASRBase):
|
||||
model_dir (str, optional): Direct path to a custom model directory.
|
||||
If specified, it overrides the `modelsize` parameter.
|
||||
"""
|
||||
from mlx_whisper import transcribe
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(
|
||||
@@ -215,6 +216,12 @@ class MLXWhisper(ASRBase):
|
||||
)
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
|
||||
# In mlx_whisper.transcribe, dtype is defined as:
|
||||
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
||||
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
|
||||
dtype = mx.float16
|
||||
ModelHolder.get_model(model_size_or_path, dtype)
|
||||
return transcribe
|
||||
|
||||
def translate_model_name(self, model_name):
|
||||
@@ -255,6 +262,8 @@ class MLXWhisper(ASRBase):
|
||||
)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
if self.transcribe_kargs:
|
||||
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
|
||||
segments = self.model(
|
||||
audio,
|
||||
language=self.original_language,
|
||||
@@ -262,7 +271,6 @@ class MLXWhisper(ASRBase):
|
||||
word_timestamps=True,
|
||||
condition_on_previous_text=True,
|
||||
path_or_hf_repo=self.model_size_or_path,
|
||||
**self.transcribe_kargs,
|
||||
)
|
||||
return segments.get("segments", [])
|
||||
|
||||
@@ -844,7 +852,7 @@ def add_shared_args(parser):
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
default="tiny",
|
||||
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
||||
","
|
||||
),
|
||||
@@ -879,14 +887,14 @@ def add_shared_args(parser):
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
default="mlx-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=True,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -895,7 +903,7 @@ def add_shared_args(parser):
|
||||
parser.add_argument(
|
||||
"--vad",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=True,
|
||||
help="Use VAD = voice activity detection, with the default parameters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -1006,8 +1014,9 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
"--audio_path",
|
||||
type=str,
|
||||
default='samples_jfk.wav',
|
||||
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
|
||||
)
|
||||
add_shared_args(parser)
|
||||
|
||||
Reference in New Issue
Block a user