mlx/fasterWhisper encoders are loaded once and shared in simulstreaming

This commit is contained in:
Quentin Fuxa
2025-08-31 12:33:19 +02:00
parent d467716e26
commit d5008ed828
6 changed files with 132 additions and 53 deletions

View File

@@ -10,6 +10,12 @@ On macOS Apple Silicon M4 :
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by only loading encoder for optimized framework:
For tiny.en, mlx whisper:
Sizes MLX whisper:
Decoder weights: 59110771 bytes
Encoder weights: 15268874 bytes

Binary file not shown.

Before

Width:  |  Height:  |  Size: 388 KiB

After

Width:  |  Height:  |  Size: 355 KiB

View File

@@ -18,6 +18,21 @@ from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
@@ -46,7 +61,10 @@ class SimulStreamingOnlineProcessor:
model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg,
loaded_model=model)
loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset):
"""
@@ -273,8 +291,18 @@ class SimulStreamingASR():
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.mlx_encoder, self.fw_encoder = None, None
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path)

View File

@@ -0,0 +1,72 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model

View File

@@ -1,15 +0,0 @@
model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}

View File

@@ -23,29 +23,22 @@ from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
import wave
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from mlx_whisper import load_models
from .mlx_model_mapping import model_mapping
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
try:
from faster_whisper import WhisperModel
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
# HAS_MLX_WHISPER = False
HAS_FASTER_WHISPER = False #Time to determine if that's really faster
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@@ -54,7 +47,13 @@ HAS_FASTER_WHISPER = False #Time to determine if that's really faster
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
@@ -63,19 +62,11 @@ class PaddedAlignAttWhisper:
else:
self.model = load_model(name=model_name, download_root=model_path)
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = model_mapping[model_name]
self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name)
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_model = WhisperModel(
model_name,
device='auto',
compute_type='auto',
)
self.feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if HAS_FASTER_WHISPER:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
@@ -398,17 +389,16 @@ class PaddedAlignAttWhisper:
if HAS_MLX_WHISPER:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_model.encoder(mlx_mel[None])
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu'
elif HAS_FASTER_WHISPER:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
# padded_audio = pad_or_trim(input_segments.detach(), N_SAMPLES)
mel_padded_2 = self.feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_model.encode(mel)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
device = 'cpu'
else:
@@ -423,8 +413,6 @@ class PaddedAlignAttWhisper:
device = mel.device
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):