pytorch - trainer - set default usage of n_epochs instead of max_iters

This commit is contained in:
yinon
2023-08-04 12:50:01 +00:00
parent 8ebfb731d8
commit 836d7b885a

View File

@@ -50,8 +50,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device
self.max_iters: int = kwargs.get("max_iters", 100)
self.n_epochs: Optional[int] = kwargs.get("n_epochs", None)
self.max_iters: int = kwargs.get("max_iters", None)
self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
if not self.max_iters and not self.n_epochs:
raise Exception("Either `max_iters` or `n_epochs` should be set.")