pytorch - trainer - bugfix step tensorboard step usage

This commit is contained in:
Yinon Polak
2023-07-15 14:37:44 +03:00
parent ffcba45b1b
commit 77f1584713

View File

@@ -59,6 +59,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.data_convertor = data_convertor
self.window_size: int = window_size
self.tb_logger = tb_logger
self.test_batch_counter = 0
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
"""
@@ -83,8 +84,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
batch_size=self.batch_size,
n_iters=self.max_iters,
)
for epoch in range(1, n_epochs + 1):
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
batch_counter = 0
for epoch in range(n_epochs):
for _, batch_data in enumerate(data_loaders_dictionary["train"]):
xb, yb = batch_data
xb = xb.to(self.device)
yb = yb.to(self.device)
@@ -94,7 +97,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
self.tb_logger.log_scalar("train_loss", loss.item(), i)
self.tb_logger.log_scalar("train_loss", loss.item(), batch_counter)
batch_counter += 1
# evaluation
if "test" in splits:
@@ -107,14 +111,15 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
split: str,
) -> None:
self.model.eval()
for i, batch_data in enumerate(data_loader_dictionary[split]):
for _, batch_data in enumerate(data_loader_dictionary[split]):
xb, yb = batch_data
xb.to(self.device)
yb.to(self.device)
yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb)
self.tb_logger.log_scalar(f"{split}_loss", loss.item(), i)
self.tb_logger.log_scalar(f"{split}_loss", loss.item(), self.test_batch_counter)
self.test_batch_counter += 1
self.model.train()