types: slightly improved typing

This commit is contained in:
Matthias
2025-04-12 10:05:11 +02:00
parent 1d22377cad
commit 7a51c9d540
3 changed files with 22 additions and 18 deletions

View File

@@ -9,7 +9,6 @@ from abc import ABC
from typing import TypeAlias
from sklearn.base import RegressorMixin
from skopt.space import Dimension # , Integer, Categorical,
from freqtrade.constants import Config
from freqtrade.exchange import timeframe_to_minutes
@@ -17,6 +16,7 @@ from freqtrade.misc import round_dict
from freqtrade.optimize.space import SKDecimal
from freqtrade.strategy import IStrategy
from freqtrade.strategy.parameters import (
DimensionProtocol,
ft_CategoricalDistribution,
ft_IntDistribution,
)
@@ -45,7 +45,7 @@ class IHyperOpt(ABC):
# Assign timeframe to be used in hyperopt
IHyperOpt.timeframe = str(config["timeframe"])
def generate_estimator(self, dimensions: list[Dimension], **kwargs) -> EstimatorType:
def generate_estimator(self, dimensions: list[DimensionProtocol], **kwargs) -> EstimatorType:
"""
Return base_estimator.
Can be any of "TPESampler", "GPSampler", "CmaEsSampler", "NSGAIISampler"
@@ -69,7 +69,7 @@ class IHyperOpt(ABC):
return roi_table
def roi_space(self) -> list[Dimension]:
def roi_space(self) -> list[DimensionProtocol]:
"""
Create a ROI space.
@@ -154,7 +154,7 @@ class IHyperOpt(ABC):
),
]
def stoploss_space(self) -> list[Dimension]:
def stoploss_space(self) -> list[DimensionProtocol]:
"""
Create a stoploss space.
@@ -178,7 +178,7 @@ class IHyperOpt(ABC):
"trailing_only_offset_is_reached": params["trailing_only_offset_is_reached"],
}
def trailing_space(self) -> list[Dimension]:
def trailing_space(self) -> list[DimensionProtocol]:
"""
Create a trailing stoploss space.
@@ -204,7 +204,7 @@ class IHyperOpt(ABC):
ft_CategoricalDistribution("trailing_only_offset_is_reached", [True, False]),
]
def max_open_trades_space(self) -> list[Dimension]:
def max_open_trades_space(self) -> list[DimensionProtocol]:
"""
Create a max open trades space.

View File

@@ -30,6 +30,7 @@ from freqtrade.optimize.hyperopt_loss.hyperopt_loss_interface import IHyperOptLo
from freqtrade.optimize.hyperopt_tools import HyperoptStateContainer, HyperoptTools
from freqtrade.optimize.optimize_reports import generate_strategy_stats
from freqtrade.resolvers.hyperopt_resolver import HyperOptLossResolver
from freqtrade.strategy.parameters import DimensionProtocol
from freqtrade.util.dry_run_wallet import get_dry_run_wallet
@@ -40,7 +41,6 @@ with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", category=ExperimentalWarning)
import optuna
from skopt.space import Dimension
from freqtrade.optimize.space.decimalspace import SKDecimal
from freqtrade.strategy.parameters import (
@@ -71,14 +71,14 @@ class HyperOptimizer:
"""
def __init__(self, config: Config) -> None:
self.buy_space: list[Dimension] = []
self.sell_space: list[Dimension] = []
self.protection_space: list[Dimension] = []
self.roi_space: list[Dimension] = []
self.stoploss_space: list[Dimension] = []
self.trailing_space: list[Dimension] = []
self.max_open_trades_space: list[Dimension] = []
self.dimensions: list[Dimension] = []
self.buy_space: list[DimensionProtocol] = []
self.sell_space: list[DimensionProtocol] = []
self.protection_space: list[DimensionProtocol] = []
self.roi_space: list[DimensionProtocol] = []
self.stoploss_space: list[DimensionProtocol] = []
self.trailing_space: list[DimensionProtocol] = []
self.max_open_trades_space: list[DimensionProtocol] = []
self.dimensions: list[DimensionProtocol] = []
self.o_dimensions: dict = {}
self.config = config
@@ -146,7 +146,7 @@ class HyperOptimizer:
def _get_params_dict(
self,
dimensions: list[Dimension],
dimensions: list[DimensionProtocol],
raw_params: dict[str, Any],
) -> dict[str, Any]:
# logger.info(f"_get_params_dict: {raw_params}")
@@ -404,7 +404,7 @@ class HyperOptimizer:
"total_profit": total_profit,
}
def convert_dimensions_to_optuna_space(self, s_dimensions: list[Dimension]) -> dict:
def convert_dimensions_to_optuna_space(self, s_dimensions: list[DimensionProtocol]) -> dict:
o_dimensions: dict[str, optuna.distributions.BaseDistribution] = {}
for original_dim in s_dimensions:
if isinstance(original_dim, SKDecimal):

View File

@@ -7,7 +7,7 @@ import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import suppress
from typing import Any, Union
from typing import Any, Protocol, Union
from freqtrade.enums import HyperoptState
from freqtrade.optimize.hyperopt_tools import HyperoptStateContainer
@@ -24,6 +24,10 @@ from freqtrade.exceptions import OperationalException
logger = logging.getLogger(__name__)
class DimensionProtocol(Protocol):
name: str
class ft_CategoricalDistribution(CategoricalDistribution):
def __init__(
self,