mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-02-13 01:30:35 +00:00
Merge pull request #10173 from freqtrade/fix/mutable_defaults
Fix mutable defaults, enable bugbear ruff rule also for freqAI code
This commit is contained in:
@@ -46,19 +46,20 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: DataFrame = DataFrame(),
|
||||
prices: DataFrame = DataFrame(),
|
||||
reward_kwargs: dict = {},
|
||||
*,
|
||||
df: DataFrame,
|
||||
prices: DataFrame,
|
||||
reward_kwargs: dict,
|
||||
window_size=10,
|
||||
starting_point=True,
|
||||
id: str = "baseenv-1", # noqa: A002
|
||||
seed: int = 1,
|
||||
config: dict = {},
|
||||
config: dict,
|
||||
live: bool = False,
|
||||
fee: float = 0.0015,
|
||||
can_short: bool = False,
|
||||
pair: str = "",
|
||||
df_raw: DataFrame = DataFrame(),
|
||||
df_raw: DataFrame,
|
||||
):
|
||||
"""
|
||||
Initializes the training/eval environment.
|
||||
|
||||
@@ -488,7 +488,7 @@ def make_env(
|
||||
seed: int,
|
||||
train_df: DataFrame,
|
||||
price: DataFrame,
|
||||
env_info: dict[str, Any] = {},
|
||||
env_info: dict[str, Any],
|
||||
) -> Callable:
|
||||
"""
|
||||
Utility function for multiprocessed env.
|
||||
|
||||
@@ -214,7 +214,7 @@ class FreqaiDataKitchen:
|
||||
self,
|
||||
unfiltered_df: DataFrame,
|
||||
training_feature_list: list,
|
||||
label_list: list = list(),
|
||||
label_list: list | None = None,
|
||||
training_filter: bool = True,
|
||||
) -> tuple[DataFrame, DataFrame]:
|
||||
"""
|
||||
@@ -244,7 +244,7 @@ class FreqaiDataKitchen:
|
||||
# we don't care about total row number (total no. datapoints) in training, we only care
|
||||
# about removing any row with NaNs
|
||||
# if labels has multiple columns (user wants to train multiple modelEs), we detect here
|
||||
labels = unfiltered_df.filter(label_list, axis=1)
|
||||
labels = unfiltered_df.filter(label_list or [], axis=1)
|
||||
drop_index_labels = pd.isnull(labels).any(axis=1)
|
||||
drop_index_labels = (
|
||||
drop_index_labels.replace(True, 1).replace(False, 0).infer_objects(copy=False)
|
||||
@@ -654,8 +654,8 @@ class FreqaiDataKitchen:
|
||||
pair: str,
|
||||
tf: str,
|
||||
strategy: IStrategy,
|
||||
corr_dataframes: dict = {},
|
||||
base_dataframes: dict = {},
|
||||
corr_dataframes: dict,
|
||||
base_dataframes: dict,
|
||||
is_corr_pairs: bool = False,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
@@ -773,10 +773,10 @@ class FreqaiDataKitchen:
|
||||
def use_strategy_to_populate_indicators( # noqa: C901
|
||||
self,
|
||||
strategy: IStrategy,
|
||||
corr_dataframes: dict = {},
|
||||
base_dataframes: dict = {},
|
||||
corr_dataframes: dict[str, DataFrame] | None = None,
|
||||
base_dataframes: dict[str, dict[str, DataFrame]] | None = None,
|
||||
pair: str = "",
|
||||
prediction_dataframe: DataFrame = pd.DataFrame(),
|
||||
prediction_dataframe: DataFrame | None = None,
|
||||
do_corr_pairs: bool = True,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
@@ -793,6 +793,10 @@ class FreqaiDataKitchen:
|
||||
:return:
|
||||
dataframe: DataFrame = dataframe containing populated indicators
|
||||
"""
|
||||
if not corr_dataframes:
|
||||
corr_dataframes = {}
|
||||
if not base_dataframes:
|
||||
base_dataframes = {}
|
||||
|
||||
# check if the user is using the deprecated populate_any_indicators function
|
||||
new_version = inspect.getsource(strategy.populate_any_indicators) == (
|
||||
@@ -822,7 +826,7 @@ class FreqaiDataKitchen:
|
||||
if tf not in corr_dataframes[p]:
|
||||
corr_dataframes[p][tf] = pd.DataFrame()
|
||||
|
||||
if not prediction_dataframe.empty:
|
||||
if prediction_dataframe is not None and not prediction_dataframe.empty:
|
||||
dataframe = prediction_dataframe.copy()
|
||||
base_dataframes[self.config["timeframe"]] = dataframe.copy()
|
||||
else:
|
||||
|
||||
@@ -618,7 +618,7 @@ class IFreqaiModel(ABC):
|
||||
)
|
||||
|
||||
unfiltered_dataframe = dk.use_strategy_to_populate_indicators(
|
||||
strategy, corr_dataframes, base_dataframes, pair
|
||||
strategy, corr_dataframes=corr_dataframes, base_dataframes=base_dataframes, pair=pair
|
||||
)
|
||||
|
||||
trained_timestamp = new_trained_timerange.stopts
|
||||
|
||||
@@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
criterion: nn.Module,
|
||||
device: str,
|
||||
data_convertor: PyTorchDataConvertor,
|
||||
model_meta_data: dict[str, Any] = {},
|
||||
model_meta_data: dict[str, Any] | None = None,
|
||||
window_size: int = 1,
|
||||
tb_logger: Any = None,
|
||||
**kwargs,
|
||||
@@ -45,6 +45,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
:param n_epochs: The maximum number batches to use for evaluation.
|
||||
:param batch_size: The size of the batches to use during training.
|
||||
"""
|
||||
if model_meta_data is None:
|
||||
model_meta_data = {}
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
|
||||
@@ -287,8 +287,6 @@ max-complexity = 12
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"freqtrade/freqai/**/*.py" = [
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"B006", # Bugbear - mutable default argument
|
||||
"B008", # bugbear - Do not perform function calls in argument defaults
|
||||
]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # allow assert in tests
|
||||
|
||||
@@ -150,7 +150,9 @@ def test_get_pair_data_for_features_with_prealoaded_data(mocker, freqai_conf):
|
||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
||||
|
||||
_, base_df = freqai.dd.get_base_and_corr_dataframes(timerange, "LTC/BTC", freqai.dk)
|
||||
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
|
||||
df = freqai.dk.get_pair_data_for_features(
|
||||
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
|
||||
)
|
||||
|
||||
assert df is base_df["5m"]
|
||||
assert not df.empty
|
||||
@@ -170,7 +172,9 @@ def test_get_pair_data_for_features_without_preloaded_data(mocker, freqai_conf):
|
||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
||||
|
||||
base_df = {"5m": pd.DataFrame()}
|
||||
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
|
||||
df = freqai.dk.get_pair_data_for_features(
|
||||
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
|
||||
)
|
||||
|
||||
assert df is not base_df["5m"]
|
||||
assert not df.empty
|
||||
|
||||
Reference in New Issue
Block a user