From dcf9bbdaea6d70187a395b0b1adb81ddbe137fdf Mon Sep 17 00:00:00 2001 From: robcaulk Date: Thu, 13 Apr 2023 12:19:34 +0200 Subject: [PATCH] ensure data is on same device as the model --- freqtrade/freqai/base_models/BasePyTorchRegressor.py | 1 + freqtrade/freqai/torch/PyTorchModelTrainer.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/freqtrade/freqai/base_models/BasePyTorchRegressor.py b/freqtrade/freqai/base_models/BasePyTorchRegressor.py index b9c5fa685..ea6fabe49 100644 --- a/freqtrade/freqai/base_models/BasePyTorchRegressor.py +++ b/freqtrade/freqai/base_models/BasePyTorchRegressor.py @@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel): device=self.device ) y = self.model.model(x) + y = y.cpu() pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]]) return (pred_df, dk.do_predict) diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index 9c1a1cb6e..8277ba937 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -143,8 +143,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): """ data_loader_dictionary = {} for split in splits: - x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"]) - y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"]) + x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device) + y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device) dataset = TensorDataset(*x, *y) data_loader = DataLoader( dataset,