diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index e6638d4fd..1692b4acf 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -59,6 +59,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): self.data_convertor = data_convertor self.window_size: int = window_size self.tb_logger = tb_logger + self.test_batch_counter = 0 def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]): """ @@ -83,8 +84,10 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): batch_size=self.batch_size, n_iters=self.max_iters, ) - for epoch in range(1, n_epochs + 1): - for i, batch_data in enumerate(data_loaders_dictionary["train"]): + + batch_counter = 0 + for epoch in range(n_epochs): + for _, batch_data in enumerate(data_loaders_dictionary["train"]): xb, yb = batch_data xb = xb.to(self.device) yb = yb.to(self.device) @@ -94,7 +97,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): self.optimizer.zero_grad(set_to_none=True) loss.backward() self.optimizer.step() - self.tb_logger.log_scalar("train_loss", loss.item(), i) + self.tb_logger.log_scalar("train_loss", loss.item(), batch_counter) + batch_counter += 1 # evaluation if "test" in splits: @@ -107,14 +111,15 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): split: str, ) -> None: self.model.eval() - for i, batch_data in enumerate(data_loader_dictionary[split]): + for _, batch_data in enumerate(data_loader_dictionary[split]): xb, yb = batch_data xb.to(self.device) yb.to(self.device) yb_pred = self.model(xb) loss = self.criterion(yb_pred, yb) - self.tb_logger.log_scalar(f"{split}_loss", loss.item(), i) + self.tb_logger.log_scalar(f"{split}_loss", loss.item(), self.test_batch_counter) + self.test_batch_counter += 1 self.model.train()