diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index c1f8c2e..9821360 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -399,17 +399,17 @@ class PaddedAlignAttWhisper: mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) - encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) + encoder_feature = torch.as_tensor(mlx_encoder_feature) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) - device = 'cpu' + device = encoder_feature.device # was hard-coded 'cpu'; follow the tensor's actual device (Apple Silicon MLX output lands on CPU) elif self.fw_encoder: audio_length_seconds = len(input_segments) / 16000 content_mel_len = int(audio_length_seconds * 100)//2 mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1) encoder_feature_ctranslate = self.fw_encoder.encode(mel) - encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate)) - device = 'cpu' + encoder_feature = torch.as_tensor(encoder_feature_ctranslate) + device = encoder_feature.device else: # mel + padding to 30s mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, @@ -668,4 +668,4 @@ class PaddedAlignAttWhisper: self._clean_cache() - return new_hypothesis, generation + return new_hypothesis, generation \ No newline at end of file