resolve conflict, ensure gpu works with transformer

This commit is contained in:
robcaulk
2023-05-19 14:39:16 +00:00
104 changed files with 3059 additions and 770 deletions

View File

@@ -45,6 +45,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
"""
Filter the prediction features data and predict with it.
:param dk: dk: The datakitchen object
:param unfiltered_df: Full dataframe for the current backtest period.
:return:
:pred_df: dataframe containing the predictions
@@ -74,11 +75,14 @@ class BasePyTorchClassifier(BasePyTorchModel):
dk.data_dictionary["prediction_features"],
device=self.device
)
self.model.model.eval()
logits = self.model.model(x)
probs = F.softmax(logits, dim=-1)
predicted_classes = torch.argmax(probs, dim=-1)
predicted_classes_str = self.decode_class_names(predicted_classes)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
# used .tolist to convert probs into an iterable, in this way Tensors
# are automatically moved to the CPU first if necessary.
pred_df_prob = DataFrame(probs.detach().tolist(), columns=class_names)
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict)

View File

@@ -27,6 +27,7 @@ class BasePyTorchModel(IFreqaiModel, ABC):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
test_size = self.freqai_info.get('data_split_parameters', {}).get('test_size')
self.splits = ["train", "test"] if test_size != 0 else ["train"]
self.window_size = self.freqai_info.get("conv_width", 1)
def train(
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs

View File

@@ -44,8 +44,8 @@ class BasePyTorchRegressor(BasePyTorchModel):
dk.data_dictionary["prediction_features"],
device=self.device
)
self.model.model.eval()
y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
pred_df = DataFrame(y.detach().tolist(), columns=[dk.label_list[0]])
pred_df = dk.denormalize_labels_from_metadata(pred_df)
return (pred_df, dk.do_predict)