simplify get_optimizer

This commit is contained in:
viotemp1
2025-03-26 16:42:09 +02:00
parent 2e06eb0e7b
commit 553dbccec7
3 changed files with 21 additions and 29 deletions

View File

@@ -170,7 +170,7 @@ class Hyperopt:
def get_asked_points( def get_asked_points(
self, n_points: int, dimensions: dict self, n_points: int, dimensions: dict
) -> tuple[list[list[Any]], list[bool]]: ) -> tuple[list[Any], list[bool]]:
""" """
Enforce points returned from `self.opt.ask` have not been already evaluated Enforce points returned from `self.opt.ask` have not been already evaluated
@@ -300,6 +300,8 @@ class Hyperopt:
asked, is_random = self.get_asked_points( asked, is_random = self.get_asked_points(
n_points=current_jobs, dimensions=self.hyperopter.o_dimensions n_points=current_jobs, dimensions=self.hyperopter.o_dimensions
) )
# asked_params = [asked1.params for asked1 in asked]
# logger.info(f"asked iteration {i}: {asked_params}")
f_val = self.run_optimizer_parallel( f_val = self.run_optimizer_parallel(
parallel, [asked1.params for asked1 in asked] parallel, [asked1.params for asked1 in asked]
) )

View File

@@ -45,6 +45,14 @@ logger = logging.getLogger(__name__)
MAX_LOSS = 100000 # just a big enough number to be bad result in loss optimization MAX_LOSS = 100000 # just a big enough number to be bad result in loss optimization
optuna_samplers_dict = {
"TPESampler": optuna.samplers.TPESampler,
"GPSampler": optuna.samplers.GPSampler,
"CmaEsSampler": optuna.samplers.CmaEsSampler,
"NSGAIISampler": optuna.samplers.NSGAIISampler,
"NSGAIIISampler": optuna.samplers.NSGAIIISampler,
"QMCSampler": optuna.samplers.QMCSampler
}
class HyperOptimizer: class HyperOptimizer:
""" """
@@ -390,17 +398,17 @@ class HyperOptimizer:
def convert_dimensions_to_optuna_space(self, s_dimensions: list[Dimension]) -> dict: def convert_dimensions_to_optuna_space(self, s_dimensions: list[Dimension]) -> dict:
o_dimensions = {} o_dimensions = {}
for original_dim in s_dimensions: for original_dim in s_dimensions:
if isinstance(original_dim, Integer): if isinstance(original_dim, SKDecimal):
o_dimensions[original_dim.name] = optuna.distributions.IntDistribution(
original_dim.low, original_dim.high, log=False, step=1
)
elif isinstance(original_dim, SKDecimal):
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution( o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
original_dim.low_orig, original_dim.low_orig,
original_dim.high_orig, original_dim.high_orig,
log=False, log=False,
step=1 / pow(10, original_dim.decimals), step=1 / pow(10, original_dim.decimals),
) )
elif isinstance(original_dim, Integer):
o_dimensions[original_dim.name] = optuna.distributions.IntDistribution(
original_dim.low, original_dim.high, log=False, step=1
)
elif isinstance(original_dim, Real): elif isinstance(original_dim, Real):
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution( o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
original_dim.low, original_dim.low,
@@ -413,7 +421,7 @@ class HyperOptimizer:
) )
else: else:
raise Exception(f"Unknown search space {original_dim} / {type(original_dim)}") raise Exception(f"Unknown search space {original_dim} / {type(original_dim)}")
# logger.info(f"convert_dimensions_to_optuna_space: {s_dimensions} - {o_dimensions}") # logger.info(f"convert_dimensions_to_optuna_space: {s_dimensions} - {o_dimensions}")
return o_dimensions return o_dimensions
def get_optimizer( def get_optimizer(
@@ -429,28 +437,11 @@ class HyperOptimizer:
# restored_sampler = pickle.load(open("sampler.pkl", "rb")) # restored_sampler = pickle.load(open("sampler.pkl", "rb"))
if isinstance(o_sampler, str): if isinstance(o_sampler, str):
if o_sampler not in ( if o_sampler not in optuna_samplers_dict.keys():
"TPESampler",
"GPSampler",
"CmaEsSampler",
"NSGAIISampler",
"NSGAIIISampler",
"QMCSampler",
):
raise OperationalException(f"Optuna Sampler {o_sampler} not supported.") raise OperationalException(f"Optuna Sampler {o_sampler} not supported.")
sampler = optuna_samplers_dict[o_sampler](seed=random_state)
if o_sampler == "TPESampler": else:
sampler = optuna.samplers.TPESampler(seed=random_state) sampler = o_sampler
elif o_sampler == "GPSampler":
sampler = optuna.samplers.GPSampler(seed=random_state)
elif o_sampler == "CmaEsSampler":
sampler = optuna.samplers.CmaEsSampler(seed=random_state)
elif o_sampler == "NSGAIISampler":
sampler = optuna.samplers.NSGAIISampler(seed=random_state)
elif o_sampler == "NSGAIIISampler":
sampler = optuna.samplers.NSGAIIISampler(seed=random_state)
elif o_sampler == "QMCSampler":
sampler = optuna.samplers.QMCSampler(seed=random_state)
logger.info(f"Using optuna sampler {o_sampler}.") logger.info(f"Using optuna sampler {o_sampler}.")
return optuna.create_study(sampler=sampler, direction="minimize") return optuna.create_study(sampler=sampler, direction="minimize")

View File

@@ -7,5 +7,4 @@ scikit-learn==1.6.1
ft-scikit-optimize==0.9.2 ft-scikit-optimize==0.9.2
filelock==3.18.0 filelock==3.18.0
optuna==4.2.1 optuna==4.2.1
optunahub==0.2.0
cmaes==0.11.1 cmaes==0.11.1