use platform to determine system and recommand mlx whisper

This commit is contained in:
Quentin Fuxa
2025-09-07 15:49:11 +02:00
parent 72f33be6f2
commit 334b338ab0
3 changed files with 11 additions and 10 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.8"
version = "0.2.8.post1"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [

View File

@@ -3,12 +3,12 @@ import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc
logger = logging.getLogger(__name__)
@@ -22,6 +22,8 @@ try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False

View File

@@ -61,6 +61,8 @@ class PaddedAlignAttWhisper:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
@@ -401,25 +403,22 @@ class PaddedAlignAttWhisper:
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = encoder_feature.device #'cpu' is apple silicon
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
encoder_feature = torch.as_tensor(encoder_feature_ctranslate)
device = encoder_feature.device
encoder_feature_ctranslate = np.array(self.fw_encoder.encode(mel))
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
device=self.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
device = mel.device
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
@@ -447,7 +446,7 @@ class PaddedAlignAttWhisper:
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False
attn_of_alignment_heads = None
@@ -658,7 +657,7 @@ class PaddedAlignAttWhisper:
### new hypothesis
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.model.device,
device=self.device,
)
self.tokens.append(new_tokens)
# TODO: test if this is redundant or not