diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index d3395219a..e2c0f5fda 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -16,13 +16,13 @@ from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed -from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.exceptions import OperationalException from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv -from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions +from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback from freqtrade.persistence import Trade @@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): 'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) th.set_num_threads(self.max_threads) self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] - self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() - self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() + self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() + self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() self.eval_callback: Optional[EvalCallback] = None self.model_type = self.freqai_info['rl_config']['model_type'] self.rl_config = self.freqai_info['rl_config'] @@ -431,7 +431,7 @@ class BaseReinforcementLearningModel(IFreqaiModel): return 0. -def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, +def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame, env_info: Dict[str, Any] = {}) -> Callable: """ diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 7f8c76956..c5511cf53 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam +from stable_baselines3.common.vec_env import SubprocVecEnv from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment @@ -16,7 +17,7 @@ class TensorboardCallback(BaseCallback): super().__init__(verbose) self.model: Any = None self.logger = None # type: Any - self.training_env: BaseEnvironment = None # type: ignore + self.training_env: Type[BaseEnvironment] = None # type: ignore self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -44,7 +45,10 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] + if isinstance(self.training_env, SubprocVecEnv): + tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] + else: + tensorboard_metrics = self.training_env.tensorboard_metrics for metric in local_info: if metric not in ["episode", "terminal_observation"]: diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 7eaaeab3e..ebc69452a 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -242,8 +242,8 @@ class IFreqaiModel(ABC): new_trained_timerange, pair, strategy, dk, data_load_timerange ) except Exception as msg: - logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. " - f"Message: {msg}, skipping.") + logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. " + f"Message: {msg}, skipping.") self.train_timer('stop', pair)