pytorch - mypy fixes

This commit is contained in:
Yinon Polak
2023-07-13 21:36:14 +03:00
parent 9fb0ce664c
commit ffcba45b1b

View File

@@ -176,7 +176,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"pytrainer": self
}, path)
def load(self, path: Path, device: str = None):
def load(self, path: Path, device: Optional[str] = None):
checkpoint = torch.load(path, map_location=device)
return self.load_from_checkpoint(checkpoint)