diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 4d540ee36..8ee3c7c56 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -68,11 +68,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.unset_outlier_removal() self.net_arch = self.rl_config.get('net_arch', [128, 128]) self.dd.model_type = import_str - if self.activate_tensorboard: - self.tensorboard_callback: TensorboardCallback = \ - TensorboardCallback(verbose=1, actions=BaseActions) - else: - self.tenorboard_callback = None + self.tensorboard_callback: TensorboardCallback = \ + TensorboardCallback(verbose=1, actions=BaseActions) def unset_outlier_removal(self): """ @@ -159,10 +156,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): best_model_save_path=str(dk.data_path)) actions = self.train_env.get_actions() - if self.activate_tensorboard: - self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) - else: - self.tensorboard_callback = None # type: ignore + self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) def pack_env_dict(self, pair: str) -> Dict[str, Any]: """ diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 9cfda05ee..ae3876d2c 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -110,7 +110,7 @@ class IFreqaiModel(ABC): if self.ft_params.get('principal_component_analysis', False) and self.continual_learning: self.ft_params.update({'principal_component_analysis': False}) logger.warning('User tried to use PCA with continual learning. Deactivating PCA.') - self.activate_tensorboard: bool = self.freqai_info.get('activate_tensorboard', True) + self.activate_tensorboard: bool = self.freqai_info.get('activate_tensorboard', False) record_params(config, self.full_path) diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index 0d6c52445..a11decc92 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -58,10 +58,14 @@ class ReinforcementLearner(BaseReinforcementLearningModel): policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=self.net_arch) + if self.activate_tensorboard: + tb_path = Path(dk.full_path / "tensorboard" / dk.pair.split('/')[0]) + else: + tb_path = None + if dk.pair not in self.dd.model_dictionary or not self.continual_learning: model = self.MODELCLASS(self.policy_type, self.train_env, policy_kwargs=policy_kwargs, - tensorboard_log=Path( - dk.full_path / "tensorboard" / dk.pair.split('/')[0]), + tensorboard_log=tb_path, **self.freqai_info.get('model_training_parameters', {}) ) else: @@ -70,14 +74,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel): model = self.dd.model_dictionary[dk.pair] model.set_env(self.train_env) - callbacks = [self.eval_callback] - - if self.activate_tensorboard: - callbacks.append(self.tensorboard_callback) - model.learn( total_timesteps=int(total_timesteps), - callback=callbacks, + callback=[self.eval_callback, self.tensorboard_callback], progress_bar=self.rl_config.get('progress_bar', False) ) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 46e25462b..575e5c7e6 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -55,9 +55,9 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, can_run_model(model) - test_tb = True - if is_mac(): - test_tb = False + # test_tb = True + # if is_mac(): + # test_tb = False model_save_ext = 'joblib' freqai_conf.update({"freqaimodel": model}) @@ -94,7 +94,7 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, strategy.freqai_info = freqai_conf.get("freqai", {}) freqai = strategy.freqai freqai.live = True - freqai.activate_tensorboard = test_tb + # freqai.activate_tensorboard = test_tb freqai.can_short = can_short freqai.dk = FreqaiDataKitchen(freqai_conf) freqai.dk.live = True @@ -233,7 +233,7 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): ("CatboostRegressor", 2, "freqai_test_strat"), ("PyTorchMLPRegressor", 2, "freqai_test_strat"), ("PyTorchTransformerRegressor", 2, "freqai_test_strat"), - ("ReinforcementLearner", 3, "freqai_rl_test_strat"), + ("ReinforcementLearner", 2, "freqai_rl_test_strat"), ("XGBoostClassifier", 2, "freqai_test_classifier"), ("LightGBMClassifier", 2, "freqai_test_classifier"), ("CatboostClassifier", 2, "freqai_test_classifier"), @@ -242,9 +242,9 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): ) def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog): can_run_model(model) - test_tb = True - if is_mac(): - test_tb = False + # test_tb = True + # if is_mac(): + # test_tb = False freqai_conf.get("freqai", {}).update({"save_backtest_models": True}) freqai_conf['runmode'] = RunMode.BACKTEST @@ -277,7 +277,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog) strategy.freqai_info = freqai_conf.get("freqai", {}) freqai = strategy.freqai freqai.live = False - freqai.activate_tensorboard = test_tb + # freqai.activate_tensorboard = test_tb freqai.dk = FreqaiDataKitchen(freqai_conf) timerange = TimeRange.parse_timerange("20180110-20180130") freqai.dd.load_all_pair_histories(timerange, freqai.dk)