mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-01-20 14:00:38 +00:00
pytorch - set n_steps type as optional
This commit is contained in:
@@ -50,9 +50,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
self.criterion = criterion
|
||||
self.model_meta_data = model_meta_data
|
||||
self.device = device
|
||||
self.n_steps: int = kwargs.get("n_steps", None)
|
||||
self.n_epochs: Optional[int] = kwargs.get("n_epochs", 10)
|
||||
if not self.n_steps and not self.n_epochs:
|
||||
self.n_steps: Optional[int] = kwargs.get("n_steps", None)
|
||||
if self.n_steps is None and not self.n_epochs:
|
||||
raise Exception("Either `n_steps` or `n_epochs` should be set.")
|
||||
|
||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||
@@ -79,12 +79,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
|
||||
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
|
||||
n_obs = len(data_dictionary["train_features"])
|
||||
n_epochs = self.n_epochs or self.calc_n_epochs(
|
||||
n_obs=n_obs,
|
||||
batch_size=self.batch_size,
|
||||
n_iters=self.n_steps,
|
||||
)
|
||||
|
||||
n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs)
|
||||
batch_counter = 0
|
||||
for _ in range(n_epochs):
|
||||
for _, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
@@ -147,8 +142,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
|
||||
return data_loader_dictionary
|
||||
|
||||
@staticmethod
|
||||
def calc_n_epochs(n_obs: int, batch_size: int, n_iters: int) -> int:
|
||||
def calc_n_epochs(self, n_obs: int) -> int:
|
||||
"""
|
||||
Calculates the number of epochs required to reach the maximum number
|
||||
of iterations specified in the model training parameters.
|
||||
@@ -156,9 +150,9 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
the motivation here is that `n_steps` is easier to optimize and keep stable,
|
||||
across different n_obs - the number of data points.
|
||||
"""
|
||||
|
||||
n_batches = n_obs // batch_size
|
||||
n_epochs = min(n_iters // n_batches, 1)
|
||||
assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
|
||||
n_batches = n_obs // self.batch_size
|
||||
n_epochs = min(self.n_steps // n_batches, 1)
|
||||
if n_epochs <= 10:
|
||||
logger.warning(
|
||||
f"Setting low n_epochs: {n_epochs}. "
|
||||
|
||||
Reference in New Issue
Block a user