torch.Tensor to torch.as_tensor

This commit is contained in:
Quentin Fuxa
2025-09-03 23:01:00 +02:00
parent e0a5cbf0e7
commit f3ad4e39e4

View File

@@ -399,17 +399,17 @@ class PaddedAlignAttWhisper:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu'
device = encoder_feature.device #'cpu' is apple silicon
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
device = 'cpu'
encoder_feature = torch.as_tensor(encoder_feature_ctranslate)
device = encoder_feature.device
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
@@ -668,4 +668,4 @@ class PaddedAlignAttWhisper:
self._clean_cache()
return new_hypothesis, generation
return new_hypothesis, generation