From a38c103fcddf529017237f6a1fa42d660de220ef Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Sun, 16 Nov 2025 21:24:14 +0100
Subject: [PATCH] simulstreaming coreml encoder compatibility

Add an optional Core ML encoder path to PaddedAlignAttWhisper. It is gated
by the module-level USE_MLCORE flag, which is off by default, and is skipped
when coremltools cannot be imported.

---
Review notes (ignored by git am):
 - The existing mlx branch is now chained with `elif`; as a separate `if`
   it let a later backend branch recompute and overwrite the Core ML
   encoder features.
 - Encoder timing goes through logger.debug instead of an unconditional
   print on the inference path.
 - Line breaks, indentation and blank context lines were reconstructed from
   a whitespace-collapsed copy of this patch. The commit id and the
   post-image blob id predate these edits; regenerate the patch with
   git format-patch if exact ids or whitespace matter.

 whisperlivekit/simul_whisper/simul_whisper.py | 59 +++++++++++++++++--
 1 file changed, 57 insertions(+), 2 deletions(-)

diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 4ba900d..bbcbf07 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -43,6 +43,33 @@ if faster_backend_available():
     from faster_whisper.feature_extractor import FeatureExtractor
     HAS_FASTER_WHISPER = True
 
+USE_MLCORE = False
+
+
+def load_coreml_encoder():
+    """Load the Core ML Whisper encoder if coremltools is available.
+
+    The model path comes from the MLCORE_ENCODER_PATH environment variable
+    and falls back to whisperlivekit/whisper/whisper_encoder.mlpackage.
+
+    Returns:
+        A ``(model, input_name, output_name)`` tuple, or None when
+        coremltools cannot be imported. input_name defaults to "mel" and
+        output_name to None when the model spec declares none.
+    """
+    try:
+        from coremltools.models import MLModel
+    except ImportError:
+        logger.warning("coremltools is not installed")
+        return None
+    COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
+    _coreml_encoder = MLModel(COREML_ENCODER_PATH)
+    spec = _coreml_encoder.get_spec()
+    _coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
+    _coreml_output_name = spec.description.output[0].name if spec.description.output else None
+    return _coreml_encoder, _coreml_input_name, _coreml_output_name
+
+
 class PaddedAlignAttWhisper:
     def __init__(
         self,
@@ -58,6 +85,10 @@ class PaddedAlignAttWhisper:
         self.fw_encoder = fw_encoder
         if fw_encoder:
             self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
+        self.coreml_encoder_tuple = None
+        if USE_MLCORE:
+            self.coreml_encoder_tuple = load_coreml_encoder()
+        self.use_mlcore = self.coreml_encoder_tuple is not None
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -402,6 +433,30 @@ class PaddedAlignAttWhisper:
 
         # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
         beg_encode = time()
+        if self.use_mlcore:
+            # Core ML takes NumPy input, so the mel is computed on CPU. The
+            # other backends below are chained with elif so they do not
+            # recompute and overwrite these features.
+            coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
+            mel_padded = log_mel_spectrogram(
+                input_segments,
+                n_mels=self.model.dims.n_mels,
+                padding=N_SAMPLES,
+                device="cpu",
+            ).unsqueeze(0)
+            mel = pad_or_trim(mel_padded, N_FRAMES)
+            content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
+            mel_np = np.ascontiguousarray(mel.numpy())
+            ml_inputs = {coreml_input_name or "mel": mel_np}
+            coreml_outputs = coreml_encoder.predict(ml_inputs)
+            if coreml_output_name and coreml_output_name in coreml_outputs:
+                encoder_feature_np = coreml_outputs[coreml_output_name]
+            else:
+                encoder_feature_np = next(iter(coreml_outputs.values()))
+            encoder_feature = torch.as_tensor(
+                np.array(encoder_feature_np),
+                device=self.device,
+            )
-        if self.mlx_encoder:
+        elif self.mlx_encoder:
             mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
             mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
@@ -430,7 +485,7 @@ class PaddedAlignAttWhisper:
             content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
             encoder_feature = self.model.encoder(mel)
         end_encode = time()
-        # print('Encoder duration:', end_encode-beg_encode)
+        logger.debug("Encoder duration: %.3fs", end_encode - beg_encode)
 
         if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
             seconds_since_start = self.segments_len() - self.first_timestamp