diff --git a/freqtrade/strategy/strategyupdater.py b/freqtrade/strategy/strategyupdater.py index 998a3aac3..8a2f4e72c 100644 --- a/freqtrade/strategy/strategyupdater.py +++ b/freqtrade/strategy/strategyupdater.py @@ -39,6 +39,17 @@ class StrategyUpdater: "sell": "exit", } + # Update function names. + # example: `np.NaN` was removed in the NumPy 2.0 release. Use `np.nan` instead. + module_replacements = { + "numpy": { + "aliases": set(), + "replacements": [ + ("NaN", "nan"), + ], + } + } + # create a dictionary that maps the old column names to the new ones rename_dict = {"buy": "enter_long", "sell": "exit_long", "buy_tag": "enter_tag"} @@ -153,16 +164,24 @@ class NameUpdater(ast_comments.NodeTransformer): def visit_Name(self, node): # if the name is in the mapping, update it node.id = self.check_dict(StrategyUpdater.name_mapping, node.id) + + for mod, info in StrategyUpdater.module_replacements.items(): + for old_attr, new_attr in info["replacements"]: + if node.id == old_attr: + node.id = new_attr return node def visit_Import(self, node): - # do not update the names in import statements + for alias in node.names: + if alias.name in StrategyUpdater.module_replacements: + as_name = alias.asname or alias.name + StrategyUpdater.module_replacements[alias.name]["aliases"].add(as_name) return node def visit_ImportFrom(self, node): - # if hasattr(node, "module"): - # if node.module == "freqtrade.strategy.hyper": - # node.module = "freqtrade.strategy" + if node.module in StrategyUpdater.module_replacements: + mod = node.module + StrategyUpdater.module_replacements[node.module]["aliases"].add(mod) return node def visit_If(self, node: ast_comments.If): @@ -182,6 +201,12 @@ class NameUpdater(ast_comments.NodeTransformer): and node.attr == "nr_of_successful_buys" ): node.attr = "nr_of_successful_entries" + if isinstance(node.value, ast_comments.Name): + for mod, info in StrategyUpdater.module_replacements.items(): + if node.value.id in info["aliases"]: + for old_attr, new_attr in info["replacements"]: + if node.attr == old_attr: + node.attr = new_attr return node def visit_ClassDef(self, node): diff --git a/tests/test_strategy_updater.py b/tests/test_strategy_updater.py index 96daee973..48f1d27d7 100644 --- a/tests/test_strategy_updater.py +++ b/tests/test_strategy_updater.py @@ -42,8 +42,10 @@ def test_strategy_updater_methods(default_conf, caplog) -> None: instance_strategy_updater = StrategyUpdater() modified_code1 = instance_strategy_updater.update_code( """ +import numpy as np class testClass(IStrategy): def populate_buy_trend(): + some_variable = np.NaN pass def populate_sell_trend(): pass @@ -62,6 +64,7 @@ class testClass(IStrategy): assert "check_exit_timeout" in modified_code1 assert "custom_exit" in modified_code1 assert "INTERFACE_VERSION = 3" in modified_code1 + assert "np.nan" in modified_code1 def test_strategy_updater_params(default_conf, caplog) -> None: