pytorch - trainer - clean code

This commit is contained in:
Yinon Polak
2023-07-15 14:43:05 +03:00
parent 77f1584713
commit d61f512e20

View File

@@ -86,7 +86,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
)
batch_counter = 0
for epoch in range(n_epochs):
for _ in range(n_epochs):
for _, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb = xb.to(self.device)
@@ -171,7 +171,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
"""
- Saving any nn.Module state_dict
- Saving model_meta_data, this dict should contain any additional data that the
user needs to store. e.g class_names for classification models.
user needs to store. e.g. class_names for classification models.
"""
torch.save({