add up to 5 retries for ask in case of duplicate params

This commit is contained in:
viotemp1
2025-05-28 09:35:20 +02:00
parent b51c937e87
commit 53383f3184

View File

@@ -92,6 +92,7 @@ class Hyperopt:
self.print_json = self.config.get("print_json", False)
self.hyperopter = HyperOptimizer(self.config, self.data_pickle_file)
self.count_skipped_epochs = 0
@staticmethod
def get_lock_filename(config: Config) -> str:
@@ -170,35 +171,46 @@ class Hyperopt:
asked.append(self.opt.ask(dimensions))
return asked
def duplicate_optuna_asked_points(self, trial: Trial) -> bool:
    """
    Report whether `trial` repeats the parameters of an already completed trial.

    :param trial: Freshly asked trial. Its `params` are compared with the `params`
        of every trial of `trial.study` that is in state `TrialState.COMPLETE`.
    :return: True if a completed trial with identical `params` exists, else False.
    """
    # The completed trials are only read for comparison, so no deep copy is requested.
    completed_trials = trial.study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    # Reporting the duplicate is left to the caller; this method only detects it.
    return any(trial.params == done.params for done in completed_trials)
def get_asked_points(
    self, n_points: int, dimensions: dict, *, max_retries: int = 5
) -> tuple[list[Any], list[bool]]:
    """
    Ask the optimizer for points that have not been evaluated yet.

    Steps:
    1. Ask for `n_points` points using `self.get_optuna_asked_points`
    2. Discard the points that `self.duplicate_optuna_asked_points` reports as duplicates
    3. While points are missing, ask for one more point, up to `max_retries` times
    4. If points are still missing, warn about every discarded point and add the
       shortfall to `self.count_skipped_epochs`

    :param n_points: Number of points wanted.
    :param dimensions: Search space, passed through to `self.get_optuna_asked_points`.
    :param max_retries: Maximum number of additional single-point asks.
    :return: Tuple of (accepted points - possibly fewer than `n_points`,
        list holding one `False` flag per accepted point).
    """
    asked_non_tried: list[Any] = []
    asked_duplicates: list[Any] = []

    def _sort_in(candidate: Any) -> None:
        # Route a freshly asked point to the accepted or the rejected list.
        if self.duplicate_optuna_asked_points(candidate):
            asked_duplicates.append(candidate)
        else:
            asked_non_tried.append(candidate)

    for asked in self.get_optuna_asked_points(n_points=n_points, dimensions=dimensions):
        _sort_in(asked)

    # Top up one point at a time so a single duplicate costs a single retry.
    retries = 0
    while retries < max_retries and len(asked_non_tried) < n_points:
        _sort_in(self.get_optuna_asked_points(n_points=1, dimensions=dimensions)[0])
        retries += 1

    missing = n_points - len(asked_non_tried)
    if missing > 0:
        for asked_duplicate in asked_duplicates:
            logger.warning(f"duplicate params for Epoch {asked_duplicate.number}")
        # Count the epochs that will really not run, not the number of rejected asks.
        self.count_skipped_epochs += missing

    # The flag list must line up one-to-one with the points actually returned.
    return asked_non_tried, [False] * len(asked_non_tried)
def evaluate_result(self, val: dict[str, Any], current: int, is_random: bool):
"""
@@ -284,6 +296,7 @@ class Hyperopt:
parallel,
[asked1.params for asked1 in asked],
)
f_val_loss = [v["loss"] for v in f_val]
for o_ask, v in zip(asked, f_val_loss, strict=False):
self.opt.tell(o_ask, v)
@@ -307,6 +320,12 @@ class Hyperopt:
except KeyboardInterrupt:
print("User interrupted..")
if self.count_skipped_epochs > 0:
logger.info(
f"{self.count_skipped_epochs} {plural(self.count_skipped_epochs, 'epoch')} "
f"skipped due to duplicate parameters."
)
logger.info(
f"{self.num_epochs_saved} {plural(self.num_epochs_saved, 'epoch')} "
f"saved to '{self.results_file}'."