From afd54d39a5b675e0ce19afcb2bfbc9c2465a9d16 Mon Sep 17 00:00:00 2001
From: steam
Date: Sun, 11 Jun 2023 20:00:12 +0300
Subject: [PATCH 1/4] add action_masks

---
 freqtrade/freqai/RL/BaseEnvironment.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 42e644f0a..d1a399c48 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -141,6 +141,9 @@ class BaseEnvironment(gym.Env):
         Unique to the environment action count. Must be inherited.
         """
 
+    def action_masks(self) -> list[bool]:
+        return [self._is_valid(action.value) for action in self.actions]
+
     def seed(self, seed: int = 1):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]

From c36547a5632c35bc76a7a977b51c3682c87531cf Mon Sep 17 00:00:00 2001
From: steam
Date: Sun, 11 Jun 2023 20:05:53 +0300
Subject: [PATCH 2/4] add maskable eval callback

---
 freqtrade/freqai/RL/BaseReinforcementLearningModel.py | 13 ++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index 8ee3c7c56..642a9edf2 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -13,7 +13,8 @@ import pandas as pd
 import torch as th
 import torch.multiprocessing
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
+from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
@@ -48,7 +49,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
         self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
         self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
-        self.eval_callback: Optional[EvalCallback] = None
+        self.eval_callback: Optional[MaskableEvalCallback] = None
         self.model_type = self.freqai_info['rl_config']['model_type']
         self.rl_config = self.freqai_info['rl_config']
         self.df_raw: DataFrame = DataFrame()
@@ -151,9 +152,11 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
         self.train_env = self.MyRLEnv(df=train_df, prices=prices_train, **env_info)
         self.eval_env = Monitor(self.MyRLEnv(df=test_df, prices=prices_test, **env_info))
-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
-                                          best_model_save_path=str(dk.data_path))
+        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
+                                                  render=False, eval_freq=len(train_df),
+                                                  best_model_save_path=str(dk.data_path),
+                                                  use_masking=(self.model_type == 'MaskablePPO' and
+                                                               is_masking_supported(self.eval_env)))
 
         actions = self.train_env.get_actions()
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
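
Usage sketch (not part of the patch series): how sb3_contrib's MaskablePPO consumes the
action_masks() hook from PATCH 1 together with the MaskableEvalCallback wiring from PATCH 2.
MaskablePPO queries the environment's action_masks() on every step, so actions that
_is_valid() rejects are never sampled. MyRLEnv, train_df, test_df, prices_train, prices_test
and env_info stand in for the objects built by BaseReinforcementLearningModel; the timestep
budget is illustrative.

    # Sketch only, assuming a gymnasium env that implements action_masks() as in PATCH 1.
    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
    from sb3_contrib.common.maskable.utils import is_masking_supported
    from stable_baselines3.common.monitor import Monitor

    train_env = MyRLEnv(df=train_df, prices=prices_train, **env_info)        # placeholder env
    eval_env = Monitor(MyRLEnv(df=test_df, prices=prices_test, **env_info))  # placeholder env

    # MaskablePPO discovers env.action_masks() itself; masked actions get zero probability.
    model = MaskablePPO("MlpPolicy", train_env, verbose=1)

    eval_callback = MaskableEvalCallback(eval_env, deterministic=True, render=False,
                                         eval_freq=len(train_df),
                                         use_masking=is_masking_supported(eval_env))
    model.learn(total_timesteps=10_000, callback=eval_callback)              # placeholder budget
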
From 5dee86eda7760d646927e9f064c094dc0a143fa3 Mon Sep 17 00:00:00 2001
From: steam
Date: Sun, 11 Jun 2023 21:44:57 +0300
Subject: [PATCH 3/4] fix action_masks typing list

---
 freqtrade/freqai/RL/BaseEnvironment.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index d1a399c48..91c7501c6 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -2,7 +2,7 @@ import logging
 import random
 from abc import abstractmethod
 from enum import Enum
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union
 
 import gymnasium as gym
 import numpy as np
@@ -141,7 +141,7 @@ class BaseEnvironment(gym.Env):
         Unique to the environment action count. Must be inherited.
         """
 
-    def action_masks(self) -> list[bool]:
+    def action_masks(self) -> List[bool]:
         return [self._is_valid(action.value) for action in self.actions]
 
     def seed(self, seed: int = 1):

From bdb535d0e689ca0d6bf2c83fb2f13153d329976a Mon Sep 17 00:00:00 2001
From: steam
Date: Sun, 11 Jun 2023 22:20:15 +0300
Subject: [PATCH 4/4] add maskable eval callback multiproc

---
 .../ReinforcementLearner_multiproc.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index 9f0b2d436..f014da602 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -2,7 +2,8 @@ import logging
 from typing import Any, Dict
 
 from pandas import DataFrame
-from stable_baselines3.common.callbacks import EvalCallback
+from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
 
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
@@ -55,9 +56,11 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
                                                             env_info=env_info) for i
                                                    in range(self.max_threads)]))
 
-        self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=eval_freq,
-                                          best_model_save_path=str(dk.data_path))
+        self.eval_callback = MaskableEvalCallback(self.eval_env, deterministic=True,
+                                                  render=False, eval_freq=eval_freq,
+                                                  best_model_save_path=str(dk.data_path),
+                                                  use_masking=(self.model_type == 'MaskablePPO' and
+                                                               is_masking_supported(self.eval_env)))
 
         # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS,
         # IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
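
Configuration sketch (not part of the patch series): the masking path above only activates
when use_masking evaluates to True, i.e. when the rl_config model_type is 'MaskablePPO' and
the environment exposes action_masks(); for any other model_type the callback falls back to
unmasked evaluation. The fragment below follows the rl_config keys read in
BaseReinforcementLearningModel (model_type, model_reward_parameters); the numeric values are
placeholders, not recommended settings.

    # Sketch only: rl_config fragment that makes self.model_type == 'MaskablePPO',
    # so the MaskableEvalCallback in PATCH 2 and PATCH 4 evaluates with masking enabled.
    freqai_rl_config = {
        "rl_config": {
            "model_type": "MaskablePPO",        # read into self.model_type
            "model_reward_parameters": {        # read into self.reward_params
                "rr": 1,                        # placeholder reward parameters
                "profit_aim": 0.025,
            },
        },
    }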