pytorch - trainer - revert load changes

This commit is contained in:
yinon
2023-08-04 12:52:55 +00:00
parent 777d25192c
commit d17bf6350d

View File

@@ -181,8 +181,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"pytrainer": self
}, path)
def load(self, path: Path, device: Optional[str] = None):
checkpoint = torch.load(path, map_location=device)
def load(self, path: Path):
checkpoint = torch.load(path)
return self.load_from_checkpoint(checkpoint)
def load_from_checkpoint(self, checkpoint: Dict):