mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
fix: handle numpy object_ dtype from ctranslate2 encoder (#337)
This commit is contained in:
@@ -280,13 +280,13 @@ class AlignAtt(AlignAttBase):
|
||||
if self.device == 'cpu':
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(
|
||||
encoder_feature_ctranslate, device=self.device,
|
||||
)
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError:
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_ctranslate), device=self.device,
|
||||
)
|
||||
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
||||
arr = np.array(encoder_feature_ctranslate)
|
||||
if arr.dtype == np.object_:
|
||||
arr = np.array(arr.tolist(), dtype=np.float32)
|
||||
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||
else:
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments, n_mels=self.model.dims.n_mels,
|
||||
|
||||
Reference in New Issue
Block a user