diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 15acde6fb..b9b6cdd96 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -12,12 +12,11 @@ import pandas as pd import torch as th import torch.multiprocessing from pandas import DataFrame -from stable_baselines3.common.callbacks import EvalCallback -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, EvalCallback +from stable_baselines3.common.logger import HParam from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import SubprocVecEnv -from stable_baselines3.common.logger import HParam from freqtrade.exceptions import OperationalException from freqtrade.freqai.data_kitchen import FreqaiDataKitchen @@ -157,7 +156,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.eval_callback = EvalCallback(self.eval_env, deterministic=True, render=False, eval_freq=len(train_df), best_model_save_path=str(dk.data_path)) - + self.tensorboard_callback = TensorboardCallback() @abstractmethod @@ -403,6 +402,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, set_random_seed(seed) return _init + class TensorboardCallback(BaseCallback): """ Custom callback for plotting additional values in tensorboard. @@ -422,7 +422,7 @@ class TensorboardCallback(BaseCallback): metric_dict = { "eval/mean_reward": 0, "rollout/ep_rew_mean": 0, - "rollout/ep_len_mean":0 , + "rollout/ep_len_mean": 0, "train/value_loss": 0, "train/explained_variance": 0, } @@ -431,19 +431,21 @@ class TensorboardCallback(BaseCallback): HParam(hparam_dict, metric_dict), exclude=("stdout", "log", "json", "csv"), ) - + def _on_step(self) -> bool: custom_info = self.training_env.get_attr("custom_info")[0] - self.logger.record(f"_state/position", self.locals["infos"][0]["position"]) - self.logger.record(f"_state/trade_duration", self.locals["infos"][0]["trade_duration"]) - self.logger.record(f"_state/current_profit_pct", self.locals["infos"][0]["current_profit_pct"]) - self.logger.record(f"_reward/total_profit", self.locals["infos"][0]["total_profit"]) - self.logger.record(f"_reward/total_reward", self.locals["infos"][0]["total_reward"]) - self.logger.record_mean(f"_reward/mean_trade_duration", self.locals["infos"][0]["trade_duration"]) - self.logger.record(f"_actions/action", self.locals["infos"][0]["action"]) - self.logger.record(f"_actions/_Invalid", custom_info["Invalid"]) - self.logger.record(f"_actions/_Unknown", custom_info["Unknown"]) - self.logger.record(f"_actions/Hold", custom_info["Hold"]) + self.logger.record("_state/position", self.locals["infos"][0]["position"]) + self.logger.record("_state/trade_duration", self.locals["infos"][0]["trade_duration"]) + self.logger.record("_state/current_profit_pct", self.locals["infos"] + [0]["current_profit_pct"]) + self.logger.record("_reward/total_profit", self.locals["infos"][0]["total_profit"]) + self.logger.record("_reward/total_reward", self.locals["infos"][0]["total_reward"]) + self.logger.record_mean("_reward/mean_trade_duration", self.locals["infos"] + [0]["trade_duration"]) + self.logger.record("_actions/action", self.locals["infos"][0]["action"]) + self.logger.record("_actions/_Invalid", custom_info["Invalid"]) + self.logger.record("_actions/_Unknown", custom_info["Unknown"]) + self.logger.record("_actions/Hold", custom_info["Hold"]) for action in Actions: self.logger.record(f"_actions/{action.name}", custom_info[action.name]) - return True \ No newline at end of file + return True