pytorch - ruff fixes

This commit is contained in:
Yinon Polak
2023-07-13 21:32:46 +03:00
parent 5734358d91
commit 9fb0ce664c

View File

@@ -78,7 +78,11 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
n_obs = len(data_dictionary["train_features"])
n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
n_epochs = self.n_epochs or self.calc_n_epochs(
n_obs=n_obs,
batch_size=self.batch_size,
n_iters=self.max_iters,
)
for epoch in range(1, n_epochs + 1):
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
@@ -152,7 +156,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
n_epochs = min(n_iters // n_batches, 1)
if n_epochs <= 10:
logger.warning(
f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. "
f"Setting low n_epochs: {n_epochs}. "
f"Please consider increasing `max_iters` hyper-parameter."
)