mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-18 22:01:15 +00:00
ensure data is on same device as the model
This commit is contained in:
@@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
|||||||
device=self.device
|
device=self.device
|
||||||
)
|
)
|
||||||
y = self.model.model(x)
|
y = self.model.model(x)
|
||||||
|
y = y.cpu()
|
||||||
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
||||||
return (pred_df, dk.do_predict)
|
return (pred_df, dk.do_predict)
|
||||||
|
|||||||
@@ -143,8 +143,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
"""
|
"""
|
||||||
data_loader_dictionary = {}
|
data_loader_dictionary = {}
|
||||||
for split in splits:
|
for split in splits:
|
||||||
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"])
|
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
|
||||||
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"])
|
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
|
||||||
dataset = TensorDataset(*x, *y)
|
dataset = TensorDataset(*x, *y)
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user