diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 8ee3c7c56..642a9edf2 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -13,7 +13,8 @@ import pandas as pd import torch as th import torch.multiprocessing from pandas import DataFrame -from stable_baselines3.common.callbacks import EvalCallback +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback +from sb3_contrib.common.maskable.utils import is_masking_supported from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor @@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() - self.eval_callback: Optional[EvalCallback] = None + self.eval_callback: Optional[MaskableEvalCallback] = None self.model_type = self.freqai_info['rl_config']['model_type'] self.rl_config = self.freqai_info['rl_config'] self.df_raw: DataFrame = DataFrame() @@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info) self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info)) - self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=len(train_df), - best_model_save_path=str(dk.data_path)) + self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True, + render=False, eval_freq=len(train_df), + best_model_save_path=str(dk.data_path), + use_masking=(self.model_type == 'MaskablePPO' and + is_masking_supported(self.eval_env))) actions = self.train_env.get_actions() self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)