From 9fb0ce664c76c02bdc15d351604385ad25bc43c9 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Thu, 13 Jul 2023 21:32:46 +0300 Subject: [PATCH] pytorch - ruff fixes --- freqtrade/freqai/torch/PyTorchModelTrainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index efdf3ed5a..e6691f3db 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -78,7 +78,11 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits) n_obs = len(data_dictionary["train_features"]) - n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters) + n_epochs = self.n_epochs or self.calc_n_epochs( + n_obs=n_obs, + batch_size=self.batch_size, + n_iters=self.max_iters, + ) for epoch in range(1, n_epochs + 1): for i, batch_data in enumerate(data_loaders_dictionary["train"]): xb, yb = batch_data @@ -152,7 +156,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): n_epochs = min(n_iters // n_batches, 1) if n_epochs <= 10: logger.warning( - f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. " + f"Setting low n_epochs: {n_epochs}. " f"Please consider increasing `max_iters` hyper-parameter." )