mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-02-24 21:30:51 +00:00
pytorch - ruff fixes
This commit is contained in:
@@ -78,7 +78,11 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
|
||||
data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary, splits)
|
||||
n_obs = len(data_dictionary["train_features"])
|
||||
n_epochs = self.n_epochs or self.calc_n_epochs(n_obs=n_obs, batch_size=self.batch_size, n_iters=self.max_iters)
|
||||
n_epochs = self.n_epochs or self.calc_n_epochs(
|
||||
n_obs=n_obs,
|
||||
batch_size=self.batch_size,
|
||||
n_iters=self.max_iters,
|
||||
)
|
||||
for epoch in range(1, n_epochs + 1):
|
||||
for i, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
xb, yb = batch_data
|
||||
@@ -152,7 +156,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
n_epochs = min(n_iters // n_batches, 1)
|
||||
if n_epochs <= 10:
|
||||
logger.warning(
|
||||
f"Setting low n_epochs. {n_epochs} = n_epochs = n_iters // n_batches = {n_iters} // {n_batches}. "
|
||||
f"Setting low n_epochs: {n_epochs}. "
|
||||
f"Please consider increasing `max_iters` hyper-parameter."
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user