mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-03-02 00:03:23 +00:00
pytorch - trainer - bugfix step tensorboard step usage
This commit is contained in:
@@ -59,6 +59,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
self.data_convertor = data_convertor
|
||||
self.window_size: int = window_size
|
||||
self.tb_logger = tb_logger
|
||||
self.test_batch_counter = 0
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
|
||||
"""
|
||||
@@ -83,8 +84,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
batch_size=self.batch_size,
|
||||
n_iters=self.max_iters,
|
||||
)
|
||||
for epoch in range(1, n_epochs + 1):
|
||||
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
|
||||
batch_counter = 0
|
||||
for epoch in range(n_epochs):
|
||||
for _, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
xb, yb = batch_data
|
||||
xb = xb.to(self.device)
|
||||
yb = yb.to(self.device)
|
||||
@@ -94,7 +97,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.tb_logger.log_scalar("train_loss", loss.item(), i)
|
||||
self.tb_logger.log_scalar("train_loss", loss.item(), batch_counter)
|
||||
batch_counter += 1
|
||||
|
||||
# evaluation
|
||||
if "test" in splits:
|
||||
@@ -107,14 +111,15 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
split: str,
|
||||
) -> None:
|
||||
self.model.eval()
|
||||
for i, batch_data in enumerate(data_loader_dictionary[split]):
|
||||
for _, batch_data in enumerate(data_loader_dictionary[split]):
|
||||
xb, yb = batch_data
|
||||
xb.to(self.device)
|
||||
yb.to(self.device)
|
||||
|
||||
yb_pred = self.model(xb)
|
||||
loss = self.criterion(yb_pred, yb)
|
||||
self.tb_logger.log_scalar(f"{split}_loss", loss.item(), i)
|
||||
self.tb_logger.log_scalar(f"{split}_loss", loss.item(), self.test_batch_counter)
|
||||
self.test_batch_counter += 1
|
||||
|
||||
self.model.train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user