fix: handle numpy object_ dtype from ctranslate2 encoder (#337)

This commit is contained in:
Quentin Fuxa
2026-02-19 22:18:00 +01:00
parent 4c7706e2cf
commit b8d9d7d289

View File

@@ -280,13 +280,13 @@ class AlignAtt(AlignAttBase):
if self.device == 'cpu':
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
try:
encoder_feature = torch.as_tensor(
encoder_feature_ctranslate, device=self.device,
)
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError:
encoder_feature = torch.as_tensor(
np.array(encoder_feature_ctranslate), device=self.device,
)
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
arr = np.array(encoder_feature_ctranslate)
if arr.dtype == np.object_:
arr = np.array(arr.tolist(), dtype=np.float32)
encoder_feature = torch.as_tensor(arr, device=self.device)
else:
mel_padded = log_mel_spectrogram(
input_segments, n_mels=self.model.dims.n_mels,