From d03fe1f8eee09a5f4e6aa936bceb50e4dc983c04 Mon Sep 17 00:00:00 2001 From: Richard Jozsa <38407205+richardjozsa@users.noreply.github.com> Date: Thu, 16 Mar 2023 00:53:37 +0100 Subject: [PATCH 01/11] add latest experimental version of gymnasium --- freqtrade/freqai/RL/Base3ActionRLEnv.py | 7 +++++-- freqtrade/freqai/RL/Base4ActionRLEnv.py | 7 +++++-- freqtrade/freqai/RL/Base5ActionRLEnv.py | 6 ++++-- freqtrade/freqai/RL/BaseEnvironment.py | 8 ++++---- freqtrade/freqai/RL/BaseReinforcementLearningModel.py | 2 +- requirements-freqai-rl.txt | 8 +++++--- 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py index 3b5fffc58..83682263b 100644 --- a/freqtrade/freqai/RL/Base3ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment): observation = self._get_observation() + #user can play with time if they want + truncated = False + self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done,truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py index 8f45028b1..b26ba988a 100644 --- a/freqtrade/freqai/RL/Base4ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -106,9 +106,12 @@ class Base4ActionRLEnv(BaseEnvironment): observation = self._get_observation() + #user can play with time if they want + truncated = False + self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done,truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py index 22d3cae30..6ce598dfb 100644 --- a/freqtrade/freqai/RL/Base5ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -111,10 +111,12 @@ class Base5ActionRLEnv(BaseEnvironment): ) observation = self._get_observation() + #user can play with time if they want + truncated = False self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done,truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 7a4467bf7..60b65cc03 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -4,11 +4,11 @@ from abc import abstractmethod from enum import Enum from typing import Optional, Type, Union -import gym +import gymnasium as gym import numpy as np import pandas as pd -from gym import spaces -from gym.utils import seeding +from gymnasium import spaces +from gymnasium.utils import seeding from pandas import DataFrame @@ -195,7 +195,7 @@ class BaseEnvironment(gym.Env): self.close_trade_profit = [] self._total_unrealized_profit = 1 - return self._get_observation() + return self._get_observation(), self.history @abstractmethod def step(self, action: int): diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index e10880f46..e18419d75 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Type, Union -import gym +import gymnasium as gym import numpy as np import numpy.typing as npt import pandas as pd diff --git a/requirements-freqai-rl.txt b/requirements-freqai-rl.txt index 4de7d8fab..233876425 100644 --- a/requirements-freqai-rl.txt +++ b/requirements-freqai-rl.txt @@ -3,8 +3,10 @@ # Required for freqai-rl torch==1.13.1; python_version < '3.11' -stable-baselines3==1.7.0; python_version < '3.11' -sb3-contrib==1.7.0; python_version < '3.11' +#until these branches will be released we can use this +git+https://github.com/Farama-Foundation/Gymnasium@main +git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support +git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support # Gym is forced to this version by stable-baselines3. setuptools==65.5.1 # Should be removed when gym is fixed. -gym==0.21; python_version < '3.11' + From 66c326b78935bb712a5b01fb81b570e9a7e61684 Mon Sep 17 00:00:00 2001 From: Richard Jozsa <38407205+richardjozsa@users.noreply.github.com> Date: Mon, 20 Mar 2023 15:54:58 +0100 Subject: [PATCH 02/11] Add proper handling of multiple environments --- .../RL/BaseReinforcementLearningModel.py | 4 +--- .../ReinforcementLearner_multiproc.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index e18419d75..e36d5ea5d 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -433,7 +433,6 @@ class BaseReinforcementLearningModel(IFreqaiModel): def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame, - monitor: bool = False, env_info: Dict[str, Any] = {}) -> Callable: """ Utility function for multiprocessed env. @@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info) - if monitor: - env = Monitor(env) + return env set_random_seed(seed) return _init diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index b3b8c40e6..c215e380b 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -4,7 +4,7 @@ from typing import Any, Dict from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.vec_env import SubprocVecEnv - +from stable_baselines3.common.vec_env import VecMonitor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env @@ -41,22 +41,27 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): env_info = self.pack_env_dict(dk.pair) + eval_freq = len(train_df) // self.max_threads + env_id = "train_env" - self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, + self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, train_df, prices_train, - monitor=True, + env_info=env_info) for i - in range(self.max_threads)]) + in range(self.max_threads)])) eval_env_id = 'eval_env' - self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, + self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, test_df, prices_test, - monitor=True, + env_info=env_info) for i - in range(self.max_threads)]) + in range(self.max_threads)])) + self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=len(train_df), + render=False, eval_freq=eval_freq, best_model_save_path=str(dk.data_path)) + + # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!! actions = self.train_env.env_method("get_actions")[0] self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) From c055f82e9a05a175f8b66142ed0b989b2a390f77 Mon Sep 17 00:00:00 2001 From: Richard Jozsa <38407205+richardjozsa@users.noreply.github.com> Date: Sun, 16 Apr 2023 19:28:36 +0200 Subject: [PATCH 03/11] Pip release follow up --- requirements-freqai-rl.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements-freqai-rl.txt b/requirements-freqai-rl.txt index 233876425..97d8e2c9b 100644 --- a/requirements-freqai-rl.txt +++ b/requirements-freqai-rl.txt @@ -4,9 +4,9 @@ # Required for freqai-rl torch==1.13.1; python_version < '3.11' #until these branches will be released we can use this -git+https://github.com/Farama-Foundation/Gymnasium@main -git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support -git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support +gymnasium==0.28.1 +stable_baselines3>=2.0.0a1 +sb3_contrib>=2.0.0a1 # Gym is forced to this version by stable-baselines3. setuptools==65.5.1 # Should be removed when gym is fixed. From 3fb5cd3df6efa96e65d6fa79dafb420641ef06b4 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 17 Apr 2023 20:27:18 +0200 Subject: [PATCH 04/11] Improve formatting --- freqtrade/freqai/RL/Base3ActionRLEnv.py | 4 ++-- freqtrade/freqai/RL/Base4ActionRLEnv.py | 4 ++-- freqtrade/freqai/RL/Base5ActionRLEnv.py | 4 ++-- .../RL/BaseReinforcementLearningModel.py | 2 +- .../ReinforcementLearner_multiproc.py | 24 +++++++++---------- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py index bb773658e..538ca3a6a 100644 --- a/freqtrade/freqai/RL/Base3ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py @@ -94,12 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment): observation = self._get_observation() - #user can play with time if they want + # user can play with time if they want truncated = False self._update_history(info) - return observation, step_reward, self._done,truncated, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py index aebed71ab..12f10d4fc 100644 --- a/freqtrade/freqai/RL/Base4ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py @@ -96,12 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment): observation = self._get_observation() - #user can play with time if they want + # user can play with time if they want truncated = False self._update_history(info) - return observation, step_reward, self._done,truncated, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py index d61c1a393..35d04f942 100644 --- a/freqtrade/freqai/RL/Base5ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py @@ -101,12 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment): ) observation = self._get_observation() - #user can play with time if they want + # user can play with time if they want truncated = False self._update_history(info) - return observation, step_reward, self._done,truncated, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index e36d5ea5d..d3395219a 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -449,7 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info) - + return env set_random_seed(seed) return _init diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index c215e380b..73f617027 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -3,8 +3,8 @@ from typing import Any, Dict from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback -from stable_baselines3.common.vec_env import SubprocVecEnv -from stable_baselines3.common.vec_env import VecMonitor +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor + from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env @@ -45,23 +45,21 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): env_id = "train_env" self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, - train_df, prices_train, - - env_info=env_info) for i - in range(self.max_threads)])) + train_df, prices_train, + env_info=env_info) for i + in range(self.max_threads)])) eval_env_id = 'eval_env' self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, - test_df, prices_test, - - env_info=env_info) for i - in range(self.max_threads)])) - + test_df, prices_test, + env_info=env_info) for i + in range(self.max_threads)])) + self.eval_callback = EvalCallback(self.eval_env, deterministic=True, render=False, eval_freq=eval_freq, best_model_save_path=str(dk.data_path)) - - # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!! + # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, + # IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!! actions = self.train_env.env_method("get_actions")[0] self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) From f1e03a68739fd62d36eb7f893ef6b1e233f7dc64 Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 19 Apr 2023 18:20:25 +0200 Subject: [PATCH 05/11] Update variable to better reflect it's content --- freqtrade/rpc/api_server/api_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/freqtrade/rpc/api_server/api_v1.py b/freqtrade/rpc/api_server/api_v1.py index 8ea70bb69..8aa706e62 100644 --- a/freqtrade/rpc/api_server/api_v1.py +++ b/freqtrade/rpc/api_server/api_v1.py @@ -303,11 +303,11 @@ def get_strategy(strategy: str, config=Depends(get_config)): @router.get('/freqaimodels', response_model=FreqAIModelListResponse, tags=['freqai']) def list_freqaimodels(config=Depends(get_config)): from freqtrade.resolvers.freqaimodel_resolver import FreqaiModelResolver - strategies = FreqaiModelResolver.search_all_objects( + models = FreqaiModelResolver.search_all_objects( config, False) - strategies = sorted(strategies, key=lambda x: x['name']) + models = sorted(models, key=lambda x: x['name']) - return {'freqaimodels': [x['name'] for x in strategies]} + return {'freqaimodels': [x['name'] for x in models]} @router.get('/available_pairs', response_model=AvailablePairs, tags=['candle data']) From 0a05099713096d4c21f816d0e03c5bef7d64f3ff Mon Sep 17 00:00:00 2001 From: robcaulk Date: Fri, 21 Apr 2023 22:52:19 +0200 Subject: [PATCH 06/11] fix mypy --- freqtrade/freqai/RL/BaseReinforcementLearningModel.py | 10 +++++----- freqtrade/freqai/RL/TensorboardCallback.py | 8 ++++++-- freqtrade/freqai/freqai_interface.py | 4 ++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index d3395219a..e2c0f5fda 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -16,13 +16,13 @@ from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed -from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.exceptions import OperationalException from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv -from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions +from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback from freqtrade.persistence import Trade @@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): 'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) th.set_num_threads(self.max_threads) self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] - self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() - self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() + self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() + self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() self.eval_callback: Optional[EvalCallback] = None self.model_type = self.freqai_info['rl_config']['model_type'] self.rl_config = self.freqai_info['rl_config'] @@ -431,7 +431,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): return 0. -def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, +def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame, env_info: Dict[str, Any] = {}) -> Callable: """ diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 7f8c76956..c5511cf53 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam +from stable_baselines3.common.vec_env import SubprocVecEnv from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment @@ -16,7 +17,7 @@ class TensorboardCallback(BaseCallback): super().__init__(verbose) self.model: Any = None self.logger = None # type: Any - self.training_env: BaseEnvironment = None # type: ignore + self.training_env: Type[BaseEnvironment] = None # type: ignore self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -44,7 +45,10 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] + if isinstance(self.training_env, SubprocVecEnv): + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] + else: + tensorboard_metrics = self.training_env.tensorboard_metrics for metric in local_info: if metric not in ["episode", "terminal_observation"]: diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 7eaaeab3e..ebc69452a 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -242,8 +242,8 @@ class IFreqaiModel(ABC): new_trained_timerange, pair, strategy, dk, data_load_timerange ) except Exception as msg: - logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. " - f"Message: {msg}, skipping.") + logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. " + f"Message: {msg}, skipping.") self.train_timer('stop', pair) From e29ce218ebd3dee527566a40621015c5a4ac982d Mon Sep 17 00:00:00 2001 From: robcaulk Date: Wed, 26 Apr 2023 10:54:54 +0200 Subject: [PATCH 07/11] fix typing in TensorboardCallback --- freqtrade/freqai/RL/BaseEnvironment.py | 8 ++++++++ freqtrade/freqai/RL/TensorboardCallback.py | 9 +++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 081d41202..08bb93347 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env): self.history: dict = {} self.trade_history: list = [] + def get_attr(self, attr: str): + """ + Returns the attribute of the environment + :param attr: attribute to return + :return: attribute + """ + return getattr(self, attr) + @abstractmethod def set_action_space(self): """ diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index c5511cf53..12dcd8b9d 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam -from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment @@ -17,7 +17,7 @@ class TensorboardCallback(BaseCallback): super().__init__(verbose) self.model: Any = None self.logger = None # type: Any - self.training_env: Type[BaseEnvironment] = None # type: ignore + self.training_env: Union[BaseEnvironment, SubprocVecEnv, VecMonitor] = None self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -45,10 +45,7 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - if isinstance(self.training_env, SubprocVecEnv): - tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] - else: - tensorboard_metrics = self.training_env.tensorboard_metrics + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] for metric in local_info: if metric not in ["episode", "terminal_observation"]: From e86980befacd6dd9249b57aaf43c65beef50f125 Mon Sep 17 00:00:00 2001 From: robcaulk Date: Wed, 26 Apr 2023 13:42:10 +0200 Subject: [PATCH 08/11] remove typing from callback init --- freqtrade/freqai/RL/TensorboardCallback.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 12dcd8b9d..282e60b0d 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam -from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecMonitor from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment @@ -17,7 +17,8 @@ class TensorboardCallback(BaseCallback): super().__init__(verbose) self.model: Any = None self.logger = None # type: Any - self.training_env: Union[BaseEnvironment, SubprocVecEnv, VecMonitor] = None + self.training_env: Union[BaseEnvironment, SubprocVecEnv, + VecMonitor, DummyVecEnv] = DummyVecEnv() self.actions: Type[Enum] = actions def _on_training_start(self) -> None: From c6f3a3bbca4472918c8a60a2784075936621483a Mon Sep 17 00:00:00 2001 From: robcaulk Date: Wed, 26 Apr 2023 14:11:26 +0200 Subject: [PATCH 09/11] avoid typing issues in the tensorboard callback --- freqtrade/freqai/RL/TensorboardCallback.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 282e60b0d..3924f9d2c 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,9 +3,8 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam -from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecMonitor -from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment +from freqtrade.freqai.RL.BaseEnvironment import BaseActions class TensorboardCallback(BaseCallback): @@ -17,8 +16,6 @@ class TensorboardCallback(BaseCallback): super().__init__(verbose) self.model: Any = None self.logger = None # type: Any - self.training_env: Union[BaseEnvironment, SubprocVecEnv, - VecMonitor, DummyVecEnv] = DummyVecEnv() self.actions: Type[Enum] = actions def _on_training_start(self) -> None: From 6d3c94a7398718566a91a89ede0c717e3b858b3c Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 26 Apr 2023 18:08:55 +0200 Subject: [PATCH 10/11] type: ignore the offending tensorflow call --- freqtrade/freqai/RL/TensorboardCallback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 3924f9d2c..784dc848d 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -43,7 +43,7 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] # type: ignore for metric in local_info: if metric not in ["episode", "terminal_observation"]: From 8cf0e4a316747a3c9456c9aa352d71dee1481f47 Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 26 Apr 2023 19:43:42 +0200 Subject: [PATCH 11/11] Fix mypy typing errors --- freqtrade/freqai/RL/TensorboardCallback.py | 10 ++++++++-- .../freqai/prediction_models/ReinforcementLearner.py | 7 +++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 784dc848d..61652c9c6 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam +from stable_baselines3.common.vec_env import VecEnv from freqtrade.freqai.RL.BaseEnvironment import BaseActions @@ -12,10 +13,13 @@ class TensorboardCallback(BaseCallback): Custom callback for plotting additional values in tensorboard and episodic summary reports. """ + # Override training_env type to fix type errors + training_env: Union[VecEnv, None] = None + def __init__(self, verbose=1, actions: Type[Enum] = BaseActions): super().__init__(verbose) self.model: Any = None - self.logger = None # type: Any + self.logger: Any = None self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -43,7 +47,9 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] # type: ignore + if self.training_env is None: + return True + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] for metric in local_info: if metric not in ["episode", "terminal_observation"]: diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index 65990da87..a5c2e12b5 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -1,11 +1,12 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Type import torch as th from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions +from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel @@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel): return model - class MyRLEnv(Base5ActionRLEnv): + MyRLEnv: Type[BaseEnvironment] + + class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef] """ User can override any function in BaseRLEnv and gym.Env. Here the user sets a custom reward based on profit and trade duration.