diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 42e644f0a..91c7501c6 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -2,7 +2,7 @@ import logging import random from abc import abstractmethod from enum import Enum -from typing import Optional, Type, Union +from typing import List, Optional, Type, Union import gymnasium as gym import numpy as np @@ -141,6 +141,9 @@ class BaseEnvironment(gym.Env): Unique to the environment action count. Must be inherited. """ + def action_masks(self) -> List[bool]: + return [self._is_valid(action.value) for action in self.actions] + def seed(self, seed: int = 1): self.np_random, seed = seeding.np_random(seed) return [seed] diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index 8ee3c7c56..642a9edf2 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -13,7 +13,8 @@ import pandas as pd import torch as th import torch.multiprocessing from pandas import DataFrame -from stable_baselines3.common.callbacks import EvalCallback +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback +from sb3_contrib.common.maskable.utils import is_masking_supported from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor @@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() - self.eval_callback: Optional[EvalCallback] = None + self.eval_callback: Optional[MaskableEvalCallback] = None self.model_type = self.freqai_info['rl_config']['model_type'] self.rl_config = self.freqai_info['rl_config'] self.df_raw: DataFrame = DataFrame() @@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel): self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info) self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info)) - self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=len(train_df), - best_model_save_path=str(dk.data_path)) + self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True, + render=False, eval_freq=len(train_df), + best_model_save_path=str(dk.data_path), + use_masking=(self.model_type == 'MaskablePPO' and + is_masking_supported(self.eval_env))) actions = self.train_env.get_actions() self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index 9f0b2d436..f014da602 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -2,7 +2,8 @@ import logging from typing import Any, Dict from pandas import DataFrame -from stable_baselines3.common.callbacks import EvalCallback +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback +from sb3_contrib.common.maskable.utils import is_masking_supported from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen @@ -55,9 +56,11 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): env_info=env_info) for i in range(self.max_threads)])) - self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=eval_freq, - best_model_save_path=str(dk.data_path)) + self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True, + render=False, eval_freq=eval_freq, + best_model_save_path=str(dk.data_path), + use_masking=(self.model_type == 'MaskablePPO' and + is_masking_supported(self.eval_env))) # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, # IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!