diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index 9f0b2d436..f014da602 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -2,7 +2,8 @@ import logging from typing import Any, Dict from pandas import DataFrame -from stable_baselines3.common.callbacks import EvalCallback +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback +from sb3_contrib.common.maskable.utils import is_masking_supported from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen @@ -55,9 +56,11 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): env_info=env_info) for i in range(self.max_threads)])) - self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=eval_freq, - best_model_save_path=str(dk.data_path)) + self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True, + render=False, eval_freq=eval_freq, + best_model_save_path=str(dk.data_path), + use_masking=(self.model_type == 'MaskablePPO' and + is_masking_supported(self.eval_env))) # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, # IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!