mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-01-31 11:20:24 +00:00
Merge pull request #8596 from autoscatto/bugfix/tensor-to-numpy
Bugfix/tensor to numpy
This commit is contained in:
@@ -45,6 +45,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||
"""
|
||||
Filter the prediction features data and predict with it.
|
||||
:param dk: dk: The datakitchen object
|
||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||
:return:
|
||||
:pred_df: dataframe containing the predictions
|
||||
@@ -78,7 +79,9 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
predicted_classes = torch.argmax(probs, dim=-1)
|
||||
predicted_classes_str = self.decode_class_names(predicted_classes)
|
||||
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
|
||||
# used .tolist to convert probs into an iterable, in this way Tensors
|
||||
# are automatically moved to the CPU first if necessary.
|
||||
pred_df_prob = DataFrame(probs.detach().tolist(), columns=class_names)
|
||||
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
|
||||
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
|
||||
return (pred_df, dk.do_predict)
|
||||
|
||||
@@ -45,6 +45,5 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
||||
device=self.device
|
||||
)
|
||||
y = self.model.model(x)
|
||||
y = y.cpu()
|
||||
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
||||
pred_df = DataFrame(y.detach().tolist(), columns=[dk.label_list[0]])
|
||||
return (pred_df, dk.do_predict)
|
||||
|
||||
Reference in New Issue
Block a user