From 23d2bad2a08e09fdd8ec1f02f1af318a05e510a9 Mon Sep 17 00:00:00 2001
From: yinon
Date: Fri, 4 Aug 2023 14:33:59 +0000
Subject: [PATCH] pytorch - set n_steps type as optional

---
 freqtrade/freqai/torch/PyTorchModelTrainer.py | 20 +++++++------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py
index 44f7dec4e..371a953e7 100644
--- a/freqtrade/freqai/torch/PyTorchModelTrainer.py
+++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py
@@ -50,9 +50,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         self.criterion = criterion
         self.model_meta_data = model_meta_data
         self.device = device
-        self.n_steps: int = kwargs.get("n_steps", None)
         self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
-        if not self.n_steps and not self.n_epochs:
+        self.n_steps: Optional[int] = kwargs.get("n_steps", None)
+        if self.n_steps is None and not self.n_epochs:
             raise Exception("Either `n_steps` or `n_epochs` should be set.")
         self.batch_size: int = kwargs.get("batch_size", 64)
 
@@ -79,12 +79,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
 
         data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
         n_obs = len(data_dictionary["train_features"])
-        n_epochs = self.n_epochs or self.calc_n_epochs(
-            n_obs=n_obs,
-            batch_size=self.batch_size,
-            n_iters=self.n_steps,
-        )
-
+        n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs)
         batch_counter = 0
         for _ in range(n_epochs):
             for _, batch_data in enumerate(data_loaders_dictionary["train"]):
@@ -147,8 +142,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
 
         return data_loader_dictionary
 
-    @staticmethod
-    def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
+    def calc_n_epochs(self, n_obs: int) -> int:
         """
         Calculates the number of epochs required to reach the maximum number
         of iterations specified in the model training parameters.
@@ -156,9 +150,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         the motivation here is that `n_steps` is easier to optimize and keep stable,
         across different n_obs - the number of data points.
         """
-
-        n_batches = n_obs // batch_size
-        n_epochs = min(n_iters // n_batches, 1)
+        assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
+        n_batches = n_obs // self.batch_size
+        n_epochs = min(self.n_steps // n_batches, 1)
         if n_epochs <= 10:
             logger.warning(
                 f"Setting low n_epochs: {n_epochs}. "
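
For reference, below is a minimal standalone sketch of the step-to-epoch conversion that this patch turns into an instance method. It is not freqtrade code: `n_obs`, `batch_size` and `n_steps` are passed as plain arguments instead of being read from `self`, and the logger is a local stand-in. One caveat: the diff computes `min(self.n_steps // n_batches, 1)`, which caps the result at one epoch and makes the `n_epochs <= 10` warning fire unconditionally; the docstring's goal of reaching `n_steps` iterations calls for `max`, so the sketch uses `max` and marks the difference in a comment.

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def calc_n_epochs(n_obs: int, batch_size: int, n_steps: Optional[int]) -> int:
    """Convert a target number of optimizer steps into a whole number of epochs."""
    assert isinstance(n_steps, int), "Either `n_steps` or `n_epochs` should be set."
    # Full batches per epoch; the trailing partial batch is dropped, as in the diff.
    n_batches = n_obs // batch_size
    # The diff uses min(..., 1) here, which can only ever return 1 or less;
    # max(..., 1) matches the intent of running at least n_steps iterations.
    n_epochs = max(n_steps // n_batches, 1)
    if n_epochs <= 10:
        logger.warning(f"Setting low n_epochs: {n_epochs}. Consider increasing `n_steps`.")
    return n_epochs


# 6400 training rows at batch_size=64 give 100 batches per epoch,
# so a budget of 5000 steps translates to 50 epochs.
assert calc_n_epochs(n_obs=6400, batch_size=64, n_steps=5000) == 50

Expressing the training budget as `n_steps` keeps the amount of optimization roughly constant as `n_obs` changes between retrainings, which is the motivation stated in the docstring.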