ensure data is on same device as the model

This commit is contained in:
robcaulk
2023-04-13 12:19:34 +02:00
parent 0afd5a7385
commit dcf9bbdaea
2 changed files with 3 additions and 2 deletions

View File

@@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
device=self.device device=self.device
) )
y = self.model.model(x) y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]]) pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
return (pred_df, dk.do_predict) return (pred_df, dk.do_predict)

View File

@@ -143,8 +143,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
""" """
data_loader_dictionary = {} data_loader_dictionary = {}
for split in splits: for split in splits:
x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"]) x = self.data_convertor.convert_x(data_dictionary[f"{split}_features"], self.device)
y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"]) y = self.data_convertor.convert_y(data_dictionary[f"{split}_labels"], self.device)
dataset = TensorDataset(*x, *y) dataset = TensorDataset(*x, *y)
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,