diff --git a/whisper_online.py b/whisper_online.py index 8fadec1..7b438ab 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -201,7 +201,8 @@ class MLXWhisper(ASRBase): model_dir (str, optional): Direct path to a custom model directory. If specified, it overrides the `modelsize` parameter. """ - from mlx_whisper import transcribe + from mlx_whisper.transcribe import ModelHolder, transcribe + import mlx.core as mx if model_dir is not None: logger.debug( @@ -215,6 +216,12 @@ class MLXWhisper(ASRBase): ) self.model_size_or_path = model_size_or_path + + # In mlx_whisper.transcribe, dtype is defined as: + # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 + # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16 + dtype = mx.float16 + ModelHolder.get_model(model_size_or_path, dtype) return transcribe def translate_model_name(self, model_name): @@ -255,6 +262,8 @@ class MLXWhisper(ASRBase): ) def transcribe(self, audio, init_prompt=""): + if self.transcribe_kargs: + logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.") segments = self.model( audio, language=self.original_language, @@ -262,7 +271,6 @@ class MLXWhisper(ASRBase): word_timestamps=True, condition_on_previous_text=True, path_or_hf_repo=self.model_size_or_path, - **self.transcribe_kargs, ) return segments.get("segments", []) @@ -844,7 +852,7 @@ def add_shared_args(parser): parser.add_argument( "--model", type=str, - default="large-v2", + default="tiny", choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split( "," ), @@ -879,14 +887,14 @@ def add_shared_args(parser): parser.add_argument( "--backend", type=str, - default="faster-whisper", + default="mlx-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"], help="Load only this backend for Whisper processing.", ) parser.add_argument( "--vac", action="store_true", - default=False, + default=True, help="Use VAC = voice activity controller. Recommended. Requires torch.", ) parser.add_argument( @@ -895,7 +903,7 @@ def add_shared_args(parser): parser.add_argument( "--vad", action="store_true", - default=False, + default=True, help="Use VAD = voice activity detection, with the default parameters.", ) parser.add_argument( @@ -1006,8 +1014,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "audio_path", + "--audio_path", type=str, + default='samples_jfk.wav', help="Filename of 16kHz mono channel wav, on which live streaming is simulated.", ) add_shared_args(parser)