From d6f45a12ae0778c6de86bd8020a69299ee474d31 Mon Sep 17 00:00:00 2001
From: smarmau <42020297+smarmau@users.noreply.github.com>
Date: Sat, 3 Dec 2022 22:30:04 +1100
Subject: [PATCH] add multiproc fix flake8

---
 freqtrade/freqai/prediction_models/ReinforcementLearner.py | 6 +++---
 .../prediction_models/ReinforcementLearner_multiproc.py    | 4 +++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index ff39a66e0..fa1087497 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -102,7 +102,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             for action in Actions:
                 self.custom_info[f"{action.name}"] = 0
             return super().reset()
-
+
         def step(self, action: int):
             observation, step_reward, done, info = super().step(action)
             info = dict(
@@ -134,7 +134,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             factor = 100.

             # reward agent for entering trades
-            if (action ==Actions.Long_enter.value
+            if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
                 self.custom_info[f"{Actions.Long_enter.name}"] += 1
                 return 25
@@ -174,6 +174,6 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                 factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
                 self.custom_info[f"{Actions.Short_exit.name}"] += 1
                 return float(pnl * factor)
-
+
             self.custom_info["Unknown"] += 1
             return 0.
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index 56636c1f6..dd5430aa7 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
-from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
+from freqtrade.freqai.RL.BaseReinforcementLearningModel import TensorboardCallback, make_env

 logger = logging.getLogger(__name__)

@@ -49,3 +49,5 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
                                           render=False, eval_freq=len(train_df),
                                           best_model_save_path=str(dk.data_path))
+
+        self.tensorboard_callback = TensorboardCallback()
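
---
Reviewer note, not part of the patch: a minimal sketch of where the
`tensorboard_callback` instantiated above would presumably be consumed.
Stable-Baselines3's `learn()` accepts a list of callbacks, so it can be
passed alongside the existing `eval_callback` inside the learner's `fit()`.
The names `model` and `total_timesteps` are assumed from the enclosing
`fit()` scope and do not appear in this diff:

    # Sketch under the assumptions above; `model` is the SB3 algorithm
    # instance and `total_timesteps` the configured training budget.
    model.learn(
        total_timesteps=int(total_timesteps),
        # SB3 invokes each callback in the list during training:
        # EvalCallback does periodic evaluation and best-model saving,
        # TensorboardCallback logs the custom_info action counters.
        callback=[self.eval_callback, self.tensorboard_callback],
    )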