From 66c326b78935bb712a5b01fb81b570e9a7e61684 Mon Sep 17 00:00:00 2001
From: Richard Jozsa <38407205+richardjozsa@users.noreply.github.com>
Date: Mon, 20 Mar 2023 15:54:58 +0100
Subject: [PATCH] Add proper handling of multiple environments

---
 .../RL/BaseReinforcementLearningModel.py |  4 +---
 .../ReinforcementLearner_multiproc.py    | 21 +++++++++++++--------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index e18419d75..e36d5ea5d 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -433,7 +433,6 @@ class BaseReinforcementLearningModel(IFreqaiModel):
 
 def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
              seed: int, train_df: DataFrame, price: DataFrame,
-             monitor: bool = False,
              env_info: Dict[str, Any] = {}) -> Callable:
     """
     Utility function for multiprocessed env.
@@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
 
         env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
                       **env_info)
-        if monitor:
-            env = Monitor(env)
+
         return env
     set_random_seed(seed)
     return _init
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
index b3b8c40e6..c215e380b 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py
@@ -4,7 +4,7 @@ from typing import Any, Dict
 from pandas import DataFrame
 from stable_baselines3.common.callbacks import EvalCallback
 from stable_baselines3.common.vec_env import SubprocVecEnv
-
+from stable_baselines3.common.vec_env import VecMonitor
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
@@ -41,22 +41,27 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
 
         env_info = self.pack_env_dict(dk.pair)
 
+        eval_freq = len(train_df) // self.max_threads
+
         env_id = "train_env"
-        self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
+        self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
                                                  train_df, prices_train,
-                                                 monitor=True,
+
                                                  env_info=env_info) for i
-                                        in range(self.max_threads)])
+                                        in range(self.max_threads)]))
 
         eval_env_id = 'eval_env'
-        self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
+        self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
                                                 test_df, prices_test,
-                                                monitor=True,
+
                                                 env_info=env_info) for i
-                                       in range(self.max_threads)])
+                                       in range(self.max_threads)]))
+
         self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
-                                          render=False, eval_freq=len(train_df),
+                                          render=False, eval_freq=eval_freq,
                                           best_model_save_path=str(dk.data_path))
+
+        # The Tensorboard callback is not recommended with multiple envs: it can report misleading values and is not thread safe with SB3.
         actions = self.train_env.env_method("get_actions")[0]
         self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
 
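
For reference, here is a minimal standalone sketch of the pattern this patch adopts: one picklable env factory per worker, vectorised with SubprocVecEnv, wrapped once in VecMonitor instead of wrapping each sub-process env in Monitor, and an EvalCallback whose eval_freq is divided by the worker count because SB3 counts eval_freq per vectorised step (the same reason the patch uses len(train_df) // self.max_threads). This is not freqtrade code; "CartPole-v1", PPO, and the numeric settings are placeholders, and it assumes the stable-baselines3 1.x / gym API of the patch's era.

    # Standalone illustration only; not part of this patch or of freqtrade.
    import gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import EvalCallback
    from stable_baselines3.common.utils import set_random_seed
    from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor


    def make_env(env_id: str, rank: int, seed: int = 0):
        """Return a picklable factory so SubprocVecEnv can build the env in a worker."""
        def _init() -> gym.Env:
            env = gym.make(env_id)
            env.seed(seed + rank)  # gym / SB3 1.x style seeding, as assumed above
            return env
        set_random_seed(seed)
        return _init


    if __name__ == "__main__":
        n_envs = 4
        # One VecMonitor around the whole vectorised env, mirroring the patch.
        train_env = VecMonitor(SubprocVecEnv([make_env("CartPole-v1", i) for i in range(n_envs)]))
        eval_env = VecMonitor(SubprocVecEnv([make_env("CartPole-v1", i) for i in range(n_envs)]))

        # eval_freq is counted per vectorised step, so divide the desired interval by n_envs.
        eval_callback = EvalCallback(eval_env, deterministic=True, render=False,
                                     eval_freq=10_000 // n_envs,
                                     best_model_save_path="./best_model")

        model = PPO("MlpPolicy", train_env, verbose=0)
        model.learn(total_timesteps=50_000, callback=eval_callback)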