pytorch - bugfix - explicitly assign tensor to var as .to() is not inplace operation

This commit is contained in:
yinon
2023-07-13 15:37:50 +00:00
parent 3cf419cbcd
commit 9cb45a3810

View File

@@ -83,8 +83,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb.to(self.device)
yb.to(self.device)
xb = xb.to(self.device)
yb = yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)