From 1d926f2e675d96041b9a654005f5c820c5bdfcc9 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 30 Aug 2025 22:19:11 +0200 Subject: [PATCH] mlx-whisper used as simulstreaming encoder: improve speed for macos systems --- whisperlivekit/simul_whisper/backend.py | 13 ++--- .../simul_whisper/mlx_model_mapping.py | 15 ++++++ whisperlivekit/simul_whisper/simul_whisper.py | 53 ++++++++++++++----- 3 files changed, 58 insertions(+), 23 deletions(-) create mode 100644 whisperlivekit/simul_whisper/mlx_model_mapping.py diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 4d3eaa6..4f90290 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -13,15 +13,10 @@ import os import gc logger = logging.getLogger(__name__) -try: - import torch - from whisperlivekit.simul_whisper.config import AlignAttConfig - from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper - from whisperlivekit.simul_whisper.whisper import tokenizer -except ImportError as e: - raise ImportError( - """SimulStreaming dependencies are not available. - Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""") +import torch +from whisperlivekit.simul_whisper.config import AlignAttConfig +from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper +from whisperlivekit.simul_whisper.whisper import tokenizer # TOO_MANY_REPETITIONS = 3 diff --git a/whisperlivekit/simul_whisper/mlx_model_mapping.py b/whisperlivekit/simul_whisper/mlx_model_mapping.py new file mode 100644 index 0000000..21d61c6 --- /dev/null +++ b/whisperlivekit/simul_whisper/mlx_model_mapping.py @@ -0,0 +1,15 @@ +model_mapping = { + "tiny.en": "mlx-community/whisper-tiny.en-mlx", + "tiny": "mlx-community/whisper-tiny-mlx", + "base.en": "mlx-community/whisper-base.en-mlx", + "base": "mlx-community/whisper-base-mlx", + "small.en": "mlx-community/whisper-small.en-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium.en": "mlx-community/whisper-medium.en-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v1": "mlx-community/whisper-large-v1-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "large": "mlx-community/whisper-large-mlx", +} \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index cad2143..104f379 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -14,7 +14,7 @@ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, from .beam import BeamPyTorchInference from .eow_detection import fire_at_boundary, load_cif import os - +from time import time from .token_buffer import TokenBuffer import numpy as np @@ -26,6 +26,16 @@ logger = logging.getLogger(__name__) import sys import wave +try: + from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram + from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim + from mlx_whisper import load_models + from .mlx_model_mapping import model_mapping + HAS_MLX_WHISPER = True +except ImportError: + HAS_MLX_WHISPER = False + + # New features added to the original version of Simul-Whisper: # - large-v3 model support # - translation support @@ -42,6 +52,11 @@ class PaddedAlignAttWhisper: else: self.model = load_model(name=model_name, download_root=model_path) + if HAS_MLX_WHISPER: + print('Simulstreaming will use MLX whisper for a faster encoder.') + mlx_model_name = model_mapping[model_name] + self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name) + logger.info(f"Model dimensions: {self.model.dims}") self.decode_options = DecodingOptions( @@ -359,20 +374,30 @@ class PaddedAlignAttWhisper: else: input_segments = self.segments[0] + # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder + # beg_encode = time() + if HAS_MLX_WHISPER: + mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) + mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) + mlx_encoder_feature = self.mlx_model.encoder(mlx_mel[None]) + encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) + content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) + device = 'cpu' + else: + # mel + padding to 30s + mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, + device=self.model.device).unsqueeze(0) + # trim to 3000 + mel = pad_or_trim(mel_padded, N_FRAMES) + # the len of actual audio + content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) + encoder_feature = self.model.encoder(mel) + device = mel.device + # end_encode = time() + # print('Encode time whisper', HAS_MLX_WHISPER, end_encode-beg_encode) - # mel + padding to 30s - mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, - device=self.model.device).unsqueeze(0) - # trim to 3000 - mel = pad_or_trim(mel_padded, N_FRAMES) - - # the len of actual audio - content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) - - # encode - encoder_feature = self.model.encoder(mel) - + # logger.debug(f"Encoder feature shape: {encoder_feature.shape}") # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): # logger.debug("mel ") @@ -397,7 +422,7 @@ class PaddedAlignAttWhisper: ####################### Decoding loop logger.info("Decoding loop starts\n") - sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device) + sum_logprobs = torch.zeros(self.cfg.beam_size, device=device) completed = False attn_of_alignment_heads = None