diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 71a667e..978329b 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -176,12 +176,10 @@ class TranscriptionEngine: def online_factory(args, asr): - if args.backend_policy == "simulstreaming": + if args.backend_policy == "simulstreaming": from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor - online = SimulStreamingOnlineProcessor(asr) - else: - online = OnlineASRProcessor(asr) - return online + return SimulStreamingOnlineProcessor(asr) + return OnlineASRProcessor(asr) def online_diarization_factory(args, diarization_backend): diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 04a5ea6..e85d436 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -24,9 +24,11 @@ logger = logging.getLogger(__name__) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) if HAS_MLX_WHISPER: - from .mlx_encoder import load_mlx_encoder, mlx_model_mapping + from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping + from .mlx import MLXAlignAtt else: mlx_model_mapping = {} + MLXAlignAtt = None HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER) if HAS_FASTER_WHISPER: from faster_whisper import WhisperModel @@ -36,50 +38,49 @@ else: MIN_DURATION_REAL_SILENCE = 5 class SimulStreamingOnlineProcessor: + """Online processor for SimulStreaming ASR.""" SAMPLING_RATE = 16000 - def __init__( - self, - asr, - logfile=sys.stderr, - ): + def __init__(self, asr, logfile=sys.stderr): self.asr = asr self.logfile = logfile self.end = 0.0 self.buffer = [] self.committed: List[ASRToken] = [] - self.last_result_tokens: List[ASRToken] = [] - self.load_new_alignatt_instance() + self.last_result_tokens: List[ASRToken] = [] + self.model = self._create_alignatt() if asr.tokenizer: self.model.tokenizer = asr.tokenizer + self.model.state.tokenizer = asr.tokenizer - def load_new_alignatt_instance(self): - """Initialize AlignAtt decoder using the shared model.""" - self.model = AlignAtt( - cfg=self.asr.cfg, - loaded_model=self.asr.shared_model, - mlx_encoder=self.asr.mlx_encoder, - fw_encoder=self.asr.fw_encoder, - ) + def _create_alignatt(self): + """Create the AlignAtt decoder instance based on ASR mode.""" + if self.asr.use_full_mlx and HAS_MLX_WHISPER: + return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model) + else: + return AlignAtt( + cfg=self.asr.cfg, + loaded_model=self.asr.shared_model, + mlx_encoder=self.asr.mlx_encoder, + fw_encoder=self.asr.fw_encoder, + ) def start_silence(self): tokens, processed_upto = self.process_iter(is_last=True) return tokens, processed_upto def end_silence(self, silence_duration, offset): - """ - Handle silence period. - - If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear. - Otherwise, insert a small silence and shift the last_attend_frame. - """ + """Handle silence period.""" self.end += silence_duration long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE if not long_silence: gap_len = int(16000 * silence_duration) if gap_len > 0: - gap_silence = torch.zeros(gap_len) + if self.asr.use_full_mlx: + gap_silence = np.zeros(gap_len, dtype=np.float32) + else: + gap_silence = torch.zeros(gap_len) self.model.insert_audio(gap_silence) if long_silence: self.model.refresh_segment(complete=True) @@ -87,11 +88,12 @@ class SimulStreamingOnlineProcessor: def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): """Append an audio chunk to be processed by SimulStreaming.""" - - # Convert numpy array to torch tensor - audio_tensor = torch.from_numpy(audio).float() - self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior - self.model.insert_audio(audio_tensor) + self.end = audio_stream_end_time + if self.asr.use_full_mlx: + self.model.insert_audio(audio) + else: + audio_tensor = torch.from_numpy(audio).float() + self.model.insert_audio(audio_tensor) def new_speaker(self, change_speaker: ChangeSpeaker): """Handle speaker change event.""" @@ -130,6 +132,10 @@ class SimulStreamingOnlineProcessor: def warmup(self, audio, init_prompt=""): """Warmup the SimulStreaming model.""" try: + if self.asr.use_full_mlx: + # MLX mode: ensure numpy array + if hasattr(audio, 'numpy'): + audio = audio.numpy() self.model.insert_audio(audio) self.model.infer(True) self.model.refresh_segment(complete=True) @@ -139,9 +145,14 @@ class SimulStreamingOnlineProcessor: def __del__(self): gc.collect() - torch.cuda.empty_cache() + if not getattr(self.asr, 'use_full_mlx', True) and torch is not None: + try: + torch.cuda.empty_cache() + except Exception: + pass -class SimulStreamingASR(): + +class SimulStreamingASR: """SimulStreaming backend with AlignAtt policy.""" sep = "" @@ -158,6 +169,7 @@ class SimulStreamingASR(): self.fast_encoder = False self._resolved_model_path = None self.encoder_backend = "whisper" + self.use_full_mlx = getattr(self, "use_full_mlx", False) preferred_backend = getattr(self, "backend", "auto") compatible_whisper_mlx, compatible_faster_whisper = True, True @@ -170,7 +182,7 @@ class SimulStreamingASR(): compatible_whisper_mlx = model_info.compatible_whisper_mlx compatible_faster_whisper = model_info.compatible_faster_whisper - if not model_info.has_pytorch: + if not self.use_full_mlx and not model_info.has_pytorch: raise FileNotFoundError( f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}" ) @@ -190,6 +202,10 @@ class SimulStreamingASR(): self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper") if self.encoder_backend == "whisper": self.disable_fast_encoder = True + + if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin": + if not hasattr(self, '_full_mlx_disabled'): + self.use_full_mlx = True self.cfg = AlignAttConfig( tokenizer_is_multilingual= is_multilingual, @@ -214,20 +230,36 @@ class SimulStreamingASR(): else: self.tokenizer = None - self.mlx_encoder, self.fw_encoder = None, None - if self.encoder_backend == "mlx-whisper": - print('Simulstreaming will use MLX whisper to increase encoding speed.') + self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None + self.shared_model = None + + if self.use_full_mlx and HAS_MLX_WHISPER: + logger.info('MLX Whisper backend used.') if self._resolved_model_path is not None: - mlx_model = str(self._resolved_model_path) + mlx_model_path = str(self._resolved_model_path) else: - mlx_model = mlx_model_mapping.get(self.model_name) - if not mlx_model: + mlx_model_path = mlx_model_mapping.get(self.model_name) + if not mlx_model_path: raise FileNotFoundError( f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'." ) - self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model) + self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path) + self._warmup_mlx_model() + elif self.encoder_backend == "mlx-whisper": + # hybrid mode: mlx encoder + pytorch decoder + logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.') + if self._resolved_model_path is not None: + mlx_model_path = str(self._resolved_model_path) + else: + mlx_model_path = mlx_model_mapping.get(self.model_name) + if not mlx_model_path: + raise FileNotFoundError( + f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'." + ) + self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path) + self.shared_model = self.load_model() elif self.encoder_backend == "faster-whisper": - print('Simulstreaming will use Faster Whisper for the encoder.') + print('SimulStreaming will use Faster Whisper for the encoder.') if self._resolved_model_path is not None: fw_model = str(self._resolved_model_path) else: @@ -237,7 +269,20 @@ class SimulStreamingASR(): device='auto', compute_type='auto', ) - self.shared_model = self.load_model() + self.shared_model = self.load_model() + else: + self.shared_model = self.load_model() + + def _warmup_mlx_model(self): + """Warmup the full MLX model.""" + warmup_audio = load_file(self.warmup_file) + if warmup_audio is not None: + temp_model = MLXAlignAtt( + cfg=self.cfg, + mlx_model=self.mlx_model, + ) + temp_model.warmup(warmup_audio) + logger.info("Full MLX model warmed up successfully") def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):