diff --git a/DEV_NOTES.md b/DEV_NOTES.md index 9bf286d..c41016f 100644 --- a/DEV_NOTES.md +++ b/DEV_NOTES.md @@ -10,6 +10,12 @@ On macOS Apple Silicon M4 : | FASTER_WHISPER | 0.4s | 1.20s | | MLX_WHISPER | 0.07s | 0.20s | +Memory saved by loading only the encoder when an optimized framework is available: + +Example: tiny.en with MLX Whisper: +Weight sizes: +Decoder weights: 59110771 bytes +Encoder weights: 15268874 bytes diff --git a/architecture.png b/architecture.png index 3797148..4030daa 100644 Binary files a/architecture.png and b/architecture.png differ diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 4f90290..7e1482e 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -18,6 +18,21 @@ from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper from whisperlivekit.simul_whisper.whisper import tokenizer +try: + from .mlx_encoder import mlx_model_mapping, load_mlx_encoder + HAS_MLX_WHISPER = True +except ImportError: + HAS_MLX_WHISPER = False +if HAS_MLX_WHISPER: + HAS_FASTER_WHISPER = False +else: + try: + from faster_whisper import WhisperModel + HAS_FASTER_WHISPER = True + except ImportError: + HAS_FASTER_WHISPER = False + + # TOO_MANY_REPETITIONS = 3 class SimulStreamingOnlineProcessor: @@ -46,7 +61,10 @@ class SimulStreamingOnlineProcessor: model = self.asr.get_new_model_instance() self.model = PaddedAlignAttWhisper( cfg=self.asr.cfg, - loaded_model=model) + loaded_model=model, + mlx_encoder=self.asr.mlx_encoder, + fw_encoder=self.asr.fw_encoder, + ) def insert_silence(self, silence_duration, offset): """ @@ -273,8 +291,18 @@ class SimulStreamingASR(): self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) self.models = [self.load_model() for i in range(self.preload_model_count)] - - + self.mlx_encoder, self.fw_encoder = None, None + if HAS_MLX_WHISPER: + print('Simulstreaming will use 
MLX whisper for a faster encoder.') + mlx_model_name = mlx_model_mapping[self.model_name] + self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name) + elif HAS_FASTER_WHISPER: + print('Simulstreaming will use Faster Whisper for the encoder.') + self.fw_encoder = WhisperModel( + self.model_name, + device='auto', + compute_type='auto', + ) def load_model(self): whisper_model = load_model(name=self.model_name, download_root=self.model_path) diff --git a/whisperlivekit/simul_whisper/mlx_encoder.py b/whisperlivekit/simul_whisper/mlx_encoder.py new file mode 100644 index 0000000..441166b --- /dev/null +++ b/whisperlivekit/simul_whisper/mlx_encoder.py @@ -0,0 +1,72 @@ +import json +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download +from mlx.utils import tree_unflatten + +from mlx_whisper import whisper + +mlx_model_mapping = { + "tiny.en": "mlx-community/whisper-tiny.en-mlx", + "tiny": "mlx-community/whisper-tiny-mlx", + "base.en": "mlx-community/whisper-base.en-mlx", + "base": "mlx-community/whisper-base-mlx", + "small.en": "mlx-community/whisper-small.en-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium.en": "mlx-community/whisper-medium.en-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v1": "mlx-community/whisper-large-v1-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "large": "mlx-community/whisper-large-mlx", +} + +def load_mlx_encoder( + path_or_hf_repo: str, + dtype: mx.Dtype = mx.float32, +) -> whisper.Whisper: + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path(snapshot_download(repo_id=path_or_hf_repo)) + + with open(str(model_path / "config.json"), "r") as f: + config = json.loads(f.read()) + config.pop("model_type", None) + quantization = config.pop("quantization", None) + + model_args = 
whisper.ModelDimensions(**config) + + wf = model_path / "weights.safetensors" + if not wf.exists(): + wf = model_path / "weights.npz" + weights = mx.load(str(wf)) + + model = whisper.Whisper(model_args, dtype) + + if quantization is not None: + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize(model, **quantization, class_predicate=class_predicate) + + weights = tree_unflatten(list(weights.items())) + + # we only want to load the encoder weights here. + # Size examples: for tiny.en, + # Decoder weights: 59110771 bytes + # Encoder weights: 15268874 bytes + + + encoder_weights = {} + encoder_weights['encoder'] = weights['encoder'] + del(weights) + + + + model.update(encoder_weights) + mx.eval(model.parameters()) + return model \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/mlx_model_mapping.py b/whisperlivekit/simul_whisper/mlx_model_mapping.py deleted file mode 100644 index 21d61c6..0000000 --- a/whisperlivekit/simul_whisper/mlx_model_mapping.py +++ /dev/null @@ -1,15 +0,0 @@ -model_mapping = { - "tiny.en": "mlx-community/whisper-tiny.en-mlx", - "tiny": "mlx-community/whisper-tiny-mlx", - "base.en": "mlx-community/whisper-base.en-mlx", - "base": "mlx-community/whisper-base-mlx", - "small.en": "mlx-community/whisper-small.en-mlx", - "small": "mlx-community/whisper-small-mlx", - "medium.en": "mlx-community/whisper-medium.en-mlx", - "medium": "mlx-community/whisper-medium-mlx", - "large-v1": "mlx-community/whisper-large-v1-mlx", - "large-v2": "mlx-community/whisper-large-v2-mlx", - "large-v3": "mlx-community/whisper-large-v3-mlx", - "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", - "large": "mlx-community/whisper-large-mlx", -} \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 50719ae..39101bf 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ 
b/whisperlivekit/simul_whisper/simul_whisper.py @@ -23,29 +23,22 @@ from .generation_progress import * DEC_PAD = 50257 logger = logging.getLogger(__name__) -import sys -import wave try: from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim - from mlx_whisper import load_models - from .mlx_model_mapping import model_mapping HAS_MLX_WHISPER = True except ImportError: HAS_MLX_WHISPER = False - -try: - from faster_whisper import WhisperModel - from faster_whisper.audio import pad_or_trim as fw_pad_or_trim - from faster_whisper.feature_extractor import FeatureExtractor - HAS_FASTER_WHISPER = True -except ImportError: +if HAS_MLX_WHISPER: HAS_FASTER_WHISPER = False - -# HAS_MLX_WHISPER = False -HAS_FASTER_WHISPER = False #Time to determine if that's really faster - +else: + try: + from faster_whisper.audio import pad_or_trim as fw_pad_or_trim + from faster_whisper.feature_extractor import FeatureExtractor + HAS_FASTER_WHISPER = True + except ImportError: + HAS_FASTER_WHISPER = False # New features added to the original version of Simul-Whisper: # - large-v3 model support @@ -54,7 +47,13 @@ HAS_FASTER_WHISPER = False #Time to determine if that's really faster # - prompt -- static vs. 
non-static # - context class PaddedAlignAttWhisper: - def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None: + def __init__( + self, + cfg: AlignAttConfig, + loaded_model=None, + mlx_encoder=None, + fw_encoder=None, + ) -> None: self.log_segments = 0 model_name = os.path.basename(cfg.model_path).replace(".pt", "") model_path = os.path.dirname(os.path.abspath(cfg.model_path)) @@ -63,19 +62,11 @@ class PaddedAlignAttWhisper: else: self.model = load_model(name=model_name, download_root=model_path) - if HAS_MLX_WHISPER: - print('Simulstreaming will use MLX whisper for a faster encoder.') - mlx_model_name = model_mapping[model_name] - self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name) - elif HAS_FASTER_WHISPER: - print('Simulstreaming will use Faster Whisper for the encoder.') - self.fw_model = WhisperModel( - model_name, - device='auto', - compute_type='auto', - ) - self.feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) - + self.mlx_encoder = mlx_encoder + self.fw_encoder = fw_encoder + if HAS_FASTER_WHISPER: + self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) + logger.info(f"Model dimensions: {self.model.dims}") self.decode_options = DecodingOptions( @@ -398,17 +389,16 @@ class PaddedAlignAttWhisper: if HAS_MLX_WHISPER: mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) - mlx_encoder_feature = self.mlx_model.encoder(mlx_mel[None]) + mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) device = 'cpu' elif HAS_FASTER_WHISPER: audio_length_seconds = len(input_segments) / 16000 content_mel_len = int(audio_length_seconds * 100)//2 - # padded_audio = pad_or_trim(input_segments.detach(), N_SAMPLES) - mel_padded_2 = 
self.feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] + mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1) - encoder_feature_ctranslate = self.fw_model.encode(mel) + encoder_feature_ctranslate = self.fw_encoder.encode(mel) encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate)) device = 'cpu' else: @@ -423,8 +413,6 @@ class PaddedAlignAttWhisper: device = mel.device end_encode = time() # print('Encoder duration:', end_encode-beg_encode) - - # logger.debug(f"Encoder feature shape: {encoder_feature.shape}") # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):