mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
fix duplicate params in same batch also
This commit is contained in:
@@ -15,7 +15,7 @@ from typing import Any
|
||||
|
||||
import rapidjson
|
||||
from joblib import Parallel, cpu_count
|
||||
from optuna.trial import Trial, TrialState
|
||||
from optuna.trial import FrozenTrial, Trial, TrialState
|
||||
|
||||
from freqtrade.constants import FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
|
||||
from freqtrade.enums import HyperoptState
|
||||
@@ -171,15 +171,19 @@ class Hyperopt:
|
||||
asked.append(self.opt.ask(dimensions))
|
||||
return asked
|
||||
|
||||
def duplicate_optuna_asked_points(self, trial: Trial) -> bool:
|
||||
def duplicate_optuna_asked_points(self, trial: Trial, asked_trials: list[FrozenTrial]) -> bool:
|
||||
asked_trials_no_dups: list[FrozenTrial] = []
|
||||
trials_to_consider = trial.study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
|
||||
# Check whether we already evaluated the sampled `params`.
|
||||
for t in reversed(trials_to_consider):
|
||||
if trial.params == t.params:
|
||||
# logger.warning(
|
||||
# f"duplicate trial: Trial {trial.number} has same params as {t.number}"
|
||||
# )
|
||||
return True
|
||||
# Check whether same`params` in one batch (asked_trials). Autosampler is doing this.
|
||||
for t in asked_trials:
|
||||
if t.params not in asked_trials_no_dups:
|
||||
asked_trials_no_dups.append(t)
|
||||
if len(asked_trials_no_dups) != len(asked_trials):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_asked_points(self, n_points: int, dimensions: dict) -> tuple[list[Any], list[bool]]:
|
||||
@@ -189,26 +193,26 @@ class Hyperopt:
|
||||
Steps:
|
||||
1. Try to get points using `self.opt.ask` first
|
||||
2. Discard the points that have already been evaluated
|
||||
3. Retry using `self.opt.ask` up to 5 times
|
||||
3. Retry using `self.opt.ask` up to `n_points` times
|
||||
"""
|
||||
asked_non_tried: list[list[Any]] = []
|
||||
asked_duplicates: list[Trial] = []
|
||||
asked_non_tried: list[FrozenTrial] = []
|
||||
optuna_asked_trials = self.get_optuna_asked_points(n_points=n_points, dimensions=dimensions)
|
||||
asked_non_tried += [
|
||||
x for x in optuna_asked_trials if not self.duplicate_optuna_asked_points(x)
|
||||
x
|
||||
for x in optuna_asked_trials
|
||||
if not self.duplicate_optuna_asked_points(x, optuna_asked_trials)
|
||||
]
|
||||
i = 0
|
||||
while i < 5 and len(asked_non_tried) < n_points:
|
||||
while i < 2 * n_points and len(asked_non_tried) < n_points:
|
||||
asked_new = self.get_optuna_asked_points(n_points=1, dimensions=dimensions)[0]
|
||||
if not self.duplicate_optuna_asked_points(asked_new):
|
||||
if not self.duplicate_optuna_asked_points(asked_new, asked_non_tried):
|
||||
asked_non_tried.append(asked_new)
|
||||
else:
|
||||
asked_duplicates.append(asked_new)
|
||||
i += 1
|
||||
if len(asked_duplicates) > 0 and len(asked_non_tried) < n_points:
|
||||
for asked_duplicate in asked_duplicates:
|
||||
logger.warning(f"duplicate params for Epoch {asked_duplicate.number}")
|
||||
self.count_skipped_epochs += len(asked_duplicates)
|
||||
if len(asked_non_tried) < n_points:
|
||||
logger.warning(
|
||||
"duplicate params detected. Please check if search space is not too small!"
|
||||
)
|
||||
self.count_skipped_epochs += n_points - len(asked_non_tried)
|
||||
|
||||
return asked_non_tried, [False for _ in range(len(asked_non_tried))]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user