pytorch - trainer - add assertion that either n_epochs or max_iters has been set.

This commit is contained in:
Yinon Polak
2023-07-13 20:59:33 +03:00
parent 7d28dad209
commit 5734358d91

View File

@@ -39,9 +39,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
state_dict and model_meta_data saved by self.save() method. state_dict and model_meta_data saved by self.save() method.
:param model_meta_data: Additional metadata about the model (optional). :param model_meta_data: Additional metadata about the model (optional).
:param data_convertor: convertor from pd.DataFrame to torch.tensor. :param data_convertor: convertor from pd.DataFrame to torch.tensor.
:param max_iters: The number of training iterations to run. :param max_iters: used to calculate n_epochs. The number of training iterations to run.
iteration here refers to the number of times optimizer.step() is called, iteration here refers to the number of times optimizer.step() is called.
used to calculate n_epochs. ignored if n_epochs is set. ignored if n_epochs is set.
:param n_epochs: The maximum number batches to use for evaluation. :param n_epochs: The maximum number batches to use for evaluation.
:param batch_size: The size of the batches to use during training. :param batch_size: The size of the batches to use during training.
""" """
@@ -52,6 +52,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.device = device self.device = device
self.max_iters: int = kwargs.get("max_iters", 100) self.max_iters: int = kwargs.get("max_iters", 100)
self.n_epochs: Optional[int] = kwargs.get("n_epochs", None) self.n_epochs: Optional[int] = kwargs.get("n_epochs", None)
if not self.max_iters and not self.n_epochs:
raise Exception("Either `max_iters` or `n_epochs` should be set.")
self.batch_size: int = kwargs.get("batch_size", 64) self.batch_size: int = kwargs.get("batch_size", 64)
self.data_convertor = data_convertor self.data_convertor = data_convertor
self.window_size: int = window_size self.window_size: int = window_size
@@ -75,8 +78,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits) data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
n_obs = len(data_dictionary["train_features"]) n_obs = len(data_dictionary["train_features"])
epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters) n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
for epoch in range(1, epochs + 1): for epoch in range(1, n_epochs + 1):
for i, batch_data in enumerate(data_loaders_dictionary["train"]): for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data xb, yb = batch_data
xb = xb.to(self.device) xb = xb.to(self.device)
@@ -146,14 +149,14 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
""" """
n_batches = n_obs // batch_size n_batches = n_obs // batch_size
epochs = n_iters // n_batches n_epochs = min(n_iters // n_batches, 1)
if epochs <= 10: if n_epochs <= 10:
logger.warning("User set `max_iters` in such a way that the trainer will only perform " logger.warning(
f" {epochs} epochs. Please consider increasing this value accordingly") f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. "
if epochs <= 1: f"Please consider increasing `max_iters` hyper-parameter."
logger.warning("Epochs set to 1. Please review your `max_iters` value") )
epochs = 1
return epochs return n_epochs
def save(self, path: Path): def save(self, path: Path):
""" """