diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index e816ed8..9b3368b 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -210,11 +210,15 @@ class SimulStreamingASR(): else: self.tokenizer = None - self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") - self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) + if model_dir: + self.model_name = model_dir + self.model_path = None + else: + self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") + self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) self.mlx_encoder, self.fw_encoder = None, None - if not self.disable_fast_encoder: + if not self.disable_fast_encoder and not model_dir: if HAS_MLX_WHISPER: print('Simulstreaming will use MLX whisper for a faster encoder.') mlx_model_name = mlx_model_mapping[self.model_name]