diff --git a/pyproject.toml b/pyproject.toml
index e943cf9..3129ee1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "whisperlivekit"
-version = "0.2.8"
+version = "0.2.8.post1"
 description = "Real-time speech-to-text with speaker diarization using Whisper"
 readme = "README.md"
 authors = [
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index 0b9a74a..bcdef9b 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -3,12 +3,12 @@ import numpy as np
 import logging
 from typing import List, Tuple, Optional
 import logging
+import platform
 from whisperlivekit.timed_objects import ASRToken, Transcript
 from whisperlivekit.warmup import load_file
 from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
 from .whisper import load_model, tokenizer
 from .whisper.audio import TOKENS_PER_SECOND
-
 import os
 import gc
 logger = logging.getLogger(__name__)
@@ -22,6 +22,8 @@ try:
     from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
     HAS_MLX_WHISPER = True
 except ImportError:
+    if platform.system() == "Darwin" and platform.machine() == "arm64":
+        print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
     HAS_MLX_WHISPER = False
 if HAS_MLX_WHISPER:
     HAS_FASTER_WHISPER = False
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 9821360..3b0e7c5 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -61,6 +61,8 @@ class PaddedAlignAttWhisper:
             self.model = loaded_model
         else:
             self.model = load_model(name=model_name, download_root=model_path)
+
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.mlx_encoder = mlx_encoder
         self.fw_encoder = fw_encoder
 
@@ -401,25 +403,22 @@ class PaddedAlignAttWhisper:
             mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
             encoder_feature = torch.as_tensor(mlx_encoder_feature)
             content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
-            device = encoder_feature.device #'cpu' is apple silicon
         elif self.fw_encoder:
             audio_length_seconds = len(input_segments) / 16000
             content_mel_len = int(audio_length_seconds * 100)//2
             mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
             mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
-            encoder_feature_ctranslate = self.fw_encoder.encode(mel)
-            encoder_feature = torch.as_tensor(encoder_feature_ctranslate)
-            device = encoder_feature.device
+            encoder_feature_ctranslate = np.array(self.fw_encoder.encode(mel))
+            encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
         else:
             # mel + padding to 30s
             mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels,
                                              padding=N_SAMPLES,
-                                             device=self.model.device).unsqueeze(0)
+                                             device=self.device).unsqueeze(0)
             # trim to 3000
             mel = pad_or_trim(mel_padded, N_FRAMES)
             # the len of actual audio
             content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
             encoder_feature = self.model.encoder(mel)
-            device = mel.device
         end_encode = time()
         # print('Encoder duration:', end_encode-beg_encode)
@@ -447,7 +446,7 @@ class PaddedAlignAttWhisper:
 
         ####################### Decoding loop
         logger.info("Decoding loop starts\n")
-        sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
+        sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
         completed = False
         attn_of_alignment_heads = None
 
@@ -658,7 +657,7 @@ class PaddedAlignAttWhisper:
         ### new hypothesis
         logger.debug(f"new_hypothesis: {new_hypothesis}")
         new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
-            device=self.model.device,
+            device=self.device,
         )
         self.tokens.append(new_tokens) # TODO: test if this is redundant or not
 