From 075c8c23c8bf50294e4a49b60466291dd63c2522 Mon Sep 17 00:00:00 2001
From: smarmau <42020297+smarmau@users.noreply.github.com>
Date: Sat, 3 Dec 2022 21:16:04 +1100
Subject: [PATCH] add state/action info to callbacks

---
 .../prediction_models/ReinforcementLearner.py | 44 +++++++++++++++++--
 1 file changed, 41 insertions(+), 3 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 61b01e21b..ff39a66e0 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -71,7 +71,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
 
         model.learn(
             total_timesteps=int(total_timesteps),
-            callback=self.eval_callback
+            callback=[self.eval_callback, self.tensorboard_callback]
         )
 
         if Path(dk.data_path / "best_model.zip").is_file():
@@ -88,6 +88,33 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
         User can override any function in BaseRLEnv and gym.Env. Here the user
         sets a custom reward based on profit and trade duration.
         """
+        def reset(self):
+
+            # Reset custom info
+            self.custom_info = {}
+            self.custom_info["Invalid"] = 0
+            self.custom_info["Hold"] = 0
+            self.custom_info["Unknown"] = 0
+            self.custom_info["pnl_factor"] = 0
+            self.custom_info["duration_factor"] = 0
+            self.custom_info["reward_exit"] = 0
+            self.custom_info["reward_hold"] = 0
+            for action in Actions:
+                self.custom_info[f"{action.name}"] = 0
+            return super().reset()
+
+        def step(self, action: int):
+            observation, step_reward, done, info = super().step(action)
+            info = dict(
+                tick=self._current_tick,
+                action=action,
+                total_reward=self.total_reward,
+                total_profit=self._total_profit,
+                position=self._position.value,
+                trade_duration=self.get_trade_duration(),
+                current_profit_pct=self.get_unrealized_profit()
+            )
+            return observation, step_reward, done, info
 
         def calculate_reward(self, action: int) -> float:
             """
@@ -100,17 +127,24 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._is_valid(action):
+                self.custom_info["Invalid"] += 1
                 return -2
 
             pnl = self.get_unrealized_profit()
             factor = 100.
 
             # reward agent for entering trades
-            if (action in (Actions.Long_enter.value, Actions.Short_enter.value)
+            if (action == Actions.Long_enter.value
                     and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Long_enter.name}"] += 1
+                return 25
+            if (action == Actions.Short_enter.value
+                    and self._position == Positions.Neutral):
+                self.custom_info[f"{Actions.Short_enter.name}"] += 1
                 return 25
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
+                self.custom_info[f"{Actions.Neutral.name}"] += 1
                 return -1
 
             max_trade_duration = self.rl_config.get('max_trade_duration_candles', 300)
@@ -124,18 +158,22 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
             # discourage sitting in position
             if (self._position in (Positions.Short, Positions.Long)
                     and action == Actions.Neutral.value):
+                self.custom_info["Hold"] += 1
                 return -1 * trade_duration / max_trade_duration
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Long_exit.name}"] += 1
                 return float(pnl * factor)
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
                 if pnl > self.profit_aim * self.rr:
                     factor *= self.rl_config['model_reward_parameters'].get('win_reward_factor', 2)
+                self.custom_info[f"{Actions.Short_exit.name}"] += 1
                 return float(pnl * factor)
-
+
+            self.custom_info["Unknown"] += 1
             return 0.
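
Note on the callback wiring in the first hunk: self.tensorboard_callback is not defined in this file (it is provided by the FreqAI RL base model), and the patch itself only maintains the custom_info counters in the environment. As a minimal sketch, assuming a stable_baselines3 BaseCallback is used to surface those counters in TensorBoard, the consuming side could look roughly like this; the class name CustomInfoTensorboardCallback and the logging keys are illustrative assumptions, not freqtrade's actual implementation.

# Illustrative sketch only, assuming stable_baselines3; not the freqtrade TensorboardCallback.
from stable_baselines3.common.callbacks import BaseCallback


class CustomInfoTensorboardCallback(BaseCallback):
    """Forward the environment's custom_info counters and step info to TensorBoard."""

    def _on_step(self) -> bool:
        # The training env is vectorized; read custom_info from the first worker.
        custom_info = self.training_env.get_attr("custom_info")[0]
        for metric, value in custom_info.items():
            self.logger.record(f"custom_info/{metric}", value)

        # The step() override above also returns a per-step info dict via `infos`.
        for info in self.locals.get("infos", []):
            for key in ("total_reward", "total_profit", "trade_duration"):
                if key in info:
                    self.logger.record(f"info/{key}", info[key])
        return True

A callback of this shape could then be passed alongside self.eval_callback in the callback list handed to model.learn(), as the first hunk does with self.tensorboard_callback.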