pytorch - set n_steps type as optional

Author: yinon
Date:   2023-08-04 14:33:59 +00:00
Parent: 9f69a45afd
Commit: 23d2bad2a0


@@ -50,9 +50,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         self.criterion = criterion
         self.model_meta_data = model_meta_data
         self.device = device
-        self.n_steps: int = kwargs.get("n_steps", None)
         self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
-        if not self.n_steps and not self.n_epochs:
+        self.n_steps: Optional[int] = kwargs.get("n_steps", None)
+        if self.n_steps is None and not self.n_epochs:
             raise Exception("Either `n_steps` or `n_epochs` should be set.")
         self.batch_size: int = kwargs.get("batch_size", 64)
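Side note (not part of the commit): a minimal, self-contained sketch of the kwargs handling introduced above, showing the either/or contract between `n_steps` and `n_epochs`. The helper name `resolve_epoch_config` is hypothetical and only mirrors the constructor logic.

```python
# Sketch only: mirrors the kwargs validation from the diff above.
from typing import Any, Optional, Tuple


def resolve_epoch_config(**kwargs: Any) -> Tuple[Optional[int], Optional[int]]:
    n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
    n_steps: Optional[int] = kwargs.get("n_steps", None)
    if n_steps is None and not n_epochs:
        raise Exception("Either `n_steps` or `n_epochs` should be set.")
    return n_epochs, n_steps


print(resolve_epoch_config())                              # (10, None): epoch-driven training
print(resolve_epoch_config(n_epochs=None, n_steps=5_000))  # (None, 5000): step-driven training
```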
@@ -79,12 +79,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
         n_obs = len(data_dictionary["train_features"])
-        n_epochs = self.n_epochs or self.calc_n_epochs(
-            n_obs=n_obs,
-            batch_size=self.batch_size,
-            n_iters=self.n_steps,
-        )
+        n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs)
         batch_counter = 0
         for _ in range(n_epochs):
             for _, batch_data in enumerate(data_loaders_dictionary["train"]):
@@ -147,8 +142,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         return data_loader_dictionary
-    @staticmethod
-    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
+    def calc_n_epochs(self, n_obs: int) -> int:
         """
         Calculates the number of epochs required to reach the maximum number
         of iterations specified in the model training parameters.
@@ -156,9 +150,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         the motivation here is that `n_steps` is easier to optimize and keep stable,
         across different n_obs - the number of data points.
         """
-        n_batches = n_obs // batch_size
-        n_epochs = min(n_iters // n_batches, 1)
+        assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
+        n_batches = n_obs // self.batch_size
+        n_epochs = min(self.n_steps // n_batches, 1)
         if n_epochs <= 10:
             logger.warning(
                 f"Setting low n_epochs: {n_epochs}. "