This commit is contained in:
robcaulk
2023-04-21 22:52:19 +02:00
parent f30fc29da0
commit 0a05099713
3 changed files with 13 additions and 9 deletions

View File

@@ -16,13 +16,13 @@ from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
from freqtrade.persistence import Trade from freqtrade.persistence import Trade
@@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) 'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
th.set_num_threads(self.max_threads) th.set_num_threads(self.max_threads)
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_callback: Optional[EvalCallback] = None self.eval_callback: Optional[EvalCallback] = None
self.model_type = self.freqai_info['rl_config']['model_type'] self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config'] self.rl_config = self.freqai_info['rl_config']
@@ -431,7 +431,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return 0. return 0.
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame, seed: int, train_df: DataFrame, price: DataFrame,
env_info: Dict[str, Any] = {}) -> Callable: env_info: Dict[str, Any] = {}) -> Callable:
""" """

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union
from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam from stable_baselines3.common.logger import HParam
from stable_baselines3.common.vec_env import SubprocVecEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
@@ -16,7 +17,7 @@ class TensorboardCallback(BaseCallback):
super().__init__(verbose) super().__init__(verbose)
self.model: Any = None self.model: Any = None
self.logger = None # type: Any self.logger = None # type: Any
self.training_env: BaseEnvironment = None # type: ignore self.training_env: Type[BaseEnvironment] = None # type: ignore
self.actions: Type[Enum] = actions self.actions: Type[Enum] = actions
def _on_training_start(self) -> None: def _on_training_start(self) -> None:
@@ -44,7 +45,10 @@ class TensorboardCallback(BaseCallback):
def _on_step(self) -> bool: def _on_step(self) -> bool:
local_info = self.locals["infos"][0] local_info = self.locals["infos"][0]
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] if isinstance(self.training_env, SubprocVecEnv):
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
else:
tensorboard_metrics = self.training_env.tensorboard_metrics
for metric in local_info: for metric in local_info:
if metric not in ["episode", "terminal_observation"]: if metric not in ["episode", "terminal_observation"]:

View File

@@ -242,8 +242,8 @@ class IFreqaiModel(ABC):
new_trained_timerange, pair, strategy, dk, data_load_timerange new_trained_timerange, pair, strategy, dk, data_load_timerange
) )
except Exception as msg: except Exception as msg:
logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. " logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
f"Message: {msg}, skipping.") f"Message: {msg}, skipping.")
self.train_timer('stop', pair) self.train_timer('stop', pair)