ensure data is on same device as the model

This commit is contained in:
robcaulk
2023-04-13 12:19:34 +02:00
parent 0afd5a7385
commit dcf9bbdaea
2 changed files with 3 additions and 2 deletions

View File

@@ -45,5 +45,6 @@ class BasePyTorchRegressor(BasePyTorchModel):
device=self.device
)
y = self.model.model(x)
y = y.cpu()
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
return (pred_df, dk.do_predict)