mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
simplify get_optimizer
This commit is contained in:
@@ -170,7 +170,7 @@ class Hyperopt:
|
|||||||
|
|
||||||
def get_asked_points(
|
def get_asked_points(
|
||||||
self, n_points: int, dimensions: dict
|
self, n_points: int, dimensions: dict
|
||||||
) -> tuple[list[list[Any]], list[bool]]:
|
) -> tuple[list[Any], list[bool]]:
|
||||||
"""
|
"""
|
||||||
Enforce points returned from `self.opt.ask` have not been already evaluated
|
Enforce points returned from `self.opt.ask` have not been already evaluated
|
||||||
|
|
||||||
@@ -300,6 +300,8 @@ class Hyperopt:
|
|||||||
asked, is_random = self.get_asked_points(
|
asked, is_random = self.get_asked_points(
|
||||||
n_points=current_jobs, dimensions=self.hyperopter.o_dimensions
|
n_points=current_jobs, dimensions=self.hyperopter.o_dimensions
|
||||||
)
|
)
|
||||||
|
# asked_params = [asked1.params for asked1 in asked]
|
||||||
|
# logger.info(f"asked iteration {i}: {asked_params}")
|
||||||
f_val = self.run_optimizer_parallel(
|
f_val = self.run_optimizer_parallel(
|
||||||
parallel, [asked1.params for asked1 in asked]
|
parallel, [asked1.params for asked1 in asked]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -45,6 +45,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MAX_LOSS = 100000 # just a big enough number to be bad result in loss optimization
|
MAX_LOSS = 100000 # just a big enough number to be bad result in loss optimization
|
||||||
|
|
||||||
|
optuna_samplers_dict = {
|
||||||
|
"TPESampler": optuna.samplers.TPESampler,
|
||||||
|
"GPSampler": optuna.samplers.GPSampler,
|
||||||
|
"CmaEsSampler": optuna.samplers.CmaEsSampler,
|
||||||
|
"NSGAIISampler": optuna.samplers.NSGAIISampler,
|
||||||
|
"NSGAIIISampler": optuna.samplers.NSGAIIISampler,
|
||||||
|
"QMCSampler": optuna.samplers.QMCSampler
|
||||||
|
}
|
||||||
|
|
||||||
class HyperOptimizer:
|
class HyperOptimizer:
|
||||||
"""
|
"""
|
||||||
@@ -390,17 +398,17 @@ class HyperOptimizer:
|
|||||||
def convert_dimensions_to_optuna_space(self, s_dimensions: list[Dimension]) -> dict:
|
def convert_dimensions_to_optuna_space(self, s_dimensions: list[Dimension]) -> dict:
|
||||||
o_dimensions = {}
|
o_dimensions = {}
|
||||||
for original_dim in s_dimensions:
|
for original_dim in s_dimensions:
|
||||||
if isinstance(original_dim, Integer):
|
if isinstance(original_dim, SKDecimal):
|
||||||
o_dimensions[original_dim.name] = optuna.distributions.IntDistribution(
|
|
||||||
original_dim.low, original_dim.high, log=False, step=1
|
|
||||||
)
|
|
||||||
elif isinstance(original_dim, SKDecimal):
|
|
||||||
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
|
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
|
||||||
original_dim.low_orig,
|
original_dim.low_orig,
|
||||||
original_dim.high_orig,
|
original_dim.high_orig,
|
||||||
log=False,
|
log=False,
|
||||||
step=1 / pow(10, original_dim.decimals),
|
step=1 / pow(10, original_dim.decimals),
|
||||||
)
|
)
|
||||||
|
elif isinstance(original_dim, Integer):
|
||||||
|
o_dimensions[original_dim.name] = optuna.distributions.IntDistribution(
|
||||||
|
original_dim.low, original_dim.high, log=False, step=1
|
||||||
|
)
|
||||||
elif isinstance(original_dim, Real):
|
elif isinstance(original_dim, Real):
|
||||||
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
|
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
|
||||||
original_dim.low,
|
original_dim.low,
|
||||||
@@ -413,7 +421,7 @@ class HyperOptimizer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown search space {original_dim} / {type(original_dim)}")
|
raise Exception(f"Unknown search space {original_dim} / {type(original_dim)}")
|
||||||
# logger.info(f"convert_dimensions_to_optuna_space: {s_dimensions} - {o_dimensions}")
|
# logger.info(f"convert_dimensions_to_optuna_space: {s_dimensions} - {o_dimensions}")
|
||||||
return o_dimensions
|
return o_dimensions
|
||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
@@ -429,28 +437,11 @@ class HyperOptimizer:
|
|||||||
# restored_sampler = pickle.load(open("sampler.pkl", "rb"))
|
# restored_sampler = pickle.load(open("sampler.pkl", "rb"))
|
||||||
|
|
||||||
if isinstance(o_sampler, str):
|
if isinstance(o_sampler, str):
|
||||||
if o_sampler not in (
|
if o_sampler not in optuna_samplers_dict.keys():
|
||||||
"TPESampler",
|
|
||||||
"GPSampler",
|
|
||||||
"CmaEsSampler",
|
|
||||||
"NSGAIISampler",
|
|
||||||
"NSGAIIISampler",
|
|
||||||
"QMCSampler",
|
|
||||||
):
|
|
||||||
raise OperationalException(f"Optuna Sampler {o_sampler} not supported.")
|
raise OperationalException(f"Optuna Sampler {o_sampler} not supported.")
|
||||||
|
sampler = optuna_samplers_dict[o_sampler](seed=random_state)
|
||||||
if o_sampler == "TPESampler":
|
else:
|
||||||
sampler = optuna.samplers.TPESampler(seed=random_state)
|
sampler = o_sampler
|
||||||
elif o_sampler == "GPSampler":
|
|
||||||
sampler = optuna.samplers.GPSampler(seed=random_state)
|
|
||||||
elif o_sampler == "CmaEsSampler":
|
|
||||||
sampler = optuna.samplers.CmaEsSampler(seed=random_state)
|
|
||||||
elif o_sampler == "NSGAIISampler":
|
|
||||||
sampler = optuna.samplers.NSGAIISampler(seed=random_state)
|
|
||||||
elif o_sampler == "NSGAIIISampler":
|
|
||||||
sampler = optuna.samplers.NSGAIIISampler(seed=random_state)
|
|
||||||
elif o_sampler == "QMCSampler":
|
|
||||||
sampler = optuna.samplers.QMCSampler(seed=random_state)
|
|
||||||
|
|
||||||
logger.info(f"Using optuna sampler {o_sampler}.")
|
logger.info(f"Using optuna sampler {o_sampler}.")
|
||||||
return optuna.create_study(sampler=sampler, direction="minimize")
|
return optuna.create_study(sampler=sampler, direction="minimize")
|
||||||
|
|||||||
@@ -7,5 +7,4 @@ scikit-learn==1.6.1
|
|||||||
ft-scikit-optimize==0.9.2
|
ft-scikit-optimize==0.9.2
|
||||||
filelock==3.18.0
|
filelock==3.18.0
|
||||||
optuna==4.2.1
|
optuna==4.2.1
|
||||||
optunahub==0.2.0
|
|
||||||
cmaes==0.11.1
|
cmaes==0.11.1
|
||||||
Reference in New Issue
Block a user