diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 52e6d5138..6a4b128e3 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -1,4 +1,5 @@ import logging +import math from pathlib import Path from typing import Any, Dict, Optional @@ -148,10 +149,13 @@ class PyTorchModelTrainer: """ Calculates the number of epochs required to reach the maximum number of iterations specified in the model training parameters. + + the motivation here is that `max_iters` is easier to optimize and keep stable, + across different n_obs - the number of data points. """ - n_batches = n_obs // batch_size - epochs = n_iters // n_batches + n_batches = math.ceil(n_obs // batch_size) + epochs = math.ceil(n_iters // n_batches) return epochs def save(self, path: Path):