mlx-whisper used as simulstreaming encoder: improve speed for macos systems

This commit is contained in:
Quentin Fuxa
2025-08-30 22:19:11 +02:00
parent 4a71a391b8
commit 1d926f2e67
3 changed files with 58 additions and 23 deletions

View File

@@ -13,15 +13,10 @@ import os
import gc
logger = logging.getLogger(__name__)
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
# TOO_MANY_REPETITIONS = 3

View File

@@ -0,0 +1,15 @@
# Lookup table from a Whisper model name to the identifier of its MLX build
# (an "mlx-community/..." repository id). Insertion order is kept as-is.
model_mapping = {
    # tiny
    "tiny.en": "mlx-community/whisper-tiny.en-mlx",
    "tiny": "mlx-community/whisper-tiny-mlx",
    # base
    "base.en": "mlx-community/whisper-base.en-mlx",
    "base": "mlx-community/whisper-base-mlx",
    # small
    "small.en": "mlx-community/whisper-small.en-mlx",
    "small": "mlx-community/whisper-small-mlx",
    # medium
    "medium.en": "mlx-community/whisper-medium.en-mlx",
    "medium": "mlx-community/whisper-medium-mlx",
    # large family
    "large-v1": "mlx-community/whisper-large-v1-mlx",
    "large-v2": "mlx-community/whisper-large-v2-mlx",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
    # NOTE: unlike the other entries, this identifier has no "-mlx" suffix.
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
    "large": "mlx-community/whisper-large-mlx",
}

View File

@@ -14,7 +14,7 @@ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens,
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from time import time
from .token_buffer import TokenBuffer
import numpy as np
@@ -26,6 +26,16 @@ logger = logging.getLogger(__name__)
import sys
import wave
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from mlx_whisper import load_models
from .mlx_model_mapping import model_mapping
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
@@ -42,6 +52,11 @@ class PaddedAlignAttWhisper:
else:
self.model = load_model(name=model_name, download_root=model_path)
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = model_mapping[model_name]
self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name)
logger.info(f"Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
@@ -359,20 +374,30 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
# NEW: we can use a different encoder before handing off to the standard Whisper decoder, whose hooks capture cross-attention
# beg_encode = time()
if HAS_MLX_WHISPER:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_model.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu'
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
device = mel.device
# end_encode = time()
# print('Encode time whisper', HAS_MLX_WHISPER, end_encode-beg_encode)
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
@@ -397,7 +422,7 @@ class PaddedAlignAttWhisper:
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
completed = False
attn_of_alignment_heads = None