feat: use floatDistribution for SKDecimal

This commit is contained in:
Matthias
2025-04-12 12:10:48 +02:00
parent 05f19d574a
commit 4fcc9dd587
2 changed files with 21 additions and 51 deletions

View File

@@ -407,18 +407,9 @@ class HyperOptimizer:
def convert_dimensions_to_optuna_space(self, s_dimensions: list[DimensionProtocol]) -> dict:
o_dimensions: dict[str, optuna.distributions.BaseDistribution] = {}
for original_dim in s_dimensions:
if isinstance(original_dim, SKDecimal):
o_dimensions[original_dim.name] = optuna.distributions.FloatDistribution(
original_dim.low_orig,
original_dim.high_orig,
log=False,
step=1 / pow(10, original_dim.decimals),
)
# for preparing to remove old skopt spaces
elif (
isinstance(original_dim, ft_CategoricalDistribution)
or isinstance(original_dim, ft_IntDistribution)
or isinstance(original_dim, ft_FloatDistribution)
if isinstance(
original_dim,
ft_CategoricalDistribution | ft_IntDistribution | ft_FloatDistribution | SKDecimal,
):
o_dimensions[original_dim.name] = original_dim
else:

View File

@@ -1,47 +1,26 @@
import numpy as np
from skopt.space import Integer
from optuna.distributions import FloatDistribution
class SKDecimal(Integer):
class SKDecimal(FloatDistribution):
def __init__(
self,
low,
high,
decimals=3,
prior="uniform",
base=10,
transform=None,
low: float,
high: float,
step: float | None = None,
decimals: int | None = 3,
name=None,
dtype=np.int64,
):
self.decimals = decimals
"""
FloatDistribution with a fixed step size.
"""
if decimals is not None and step is not None:
raise ValueError("You can only set one of decimals or step")
# Convert decimals to step
self.step = step or 1 / 10**decimals
self.name = name
self.pow_dot_one = pow(0.1, self.decimals)
self.pow_ten = pow(10, self.decimals)
_low = int(low * self.pow_ten)
_high = int(high * self.pow_ten)
# trunc to precision to avoid points out of space
self.low_orig = round(_low * self.pow_dot_one, self.decimals)
self.high_orig = round(_high * self.pow_dot_one, self.decimals)
super().__init__(_low, _high, prior, base, transform, name, dtype)
def __repr__(self):
return (
f"Decimal(low={self.low_orig}, high={self.high_orig}, decimals={self.decimals}, "
f"prior='{self.prior}', transform='{self.transform_}')"
super().__init__(
low=low,
high=high,
step=self.step,
)
def __contains__(self, point):
if isinstance(point, list):
point = np.array(point)
return self.low_orig <= point <= self.high_orig
def transform(self, Xt):
return super().transform([int(v * self.pow_ten) for v in Xt])
def inverse_transform(self, Xt):
res = super().inverse_transform(Xt)
# equivalent to [round(x * pow(0.1, self.decimals), self.decimals) for x in res]
return [int(v) / self.pow_ten for v in res]