fix typing in TensorboardCallback

This commit is contained in:
robcaulk
2023-04-26 10:54:54 +02:00
parent 0a05099713
commit e29ce218eb
2 changed files with 11 additions and 6 deletions

View File

@@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env):
self.history: dict = {} self.history: dict = {}
self.trade_history: list = [] self.trade_history: list = []
def get_attr(self, attr: str):
"""
Returns the attribute of the environment
:param attr: attribute to return
:return: attribute
"""
return getattr(self, attr)
@abstractmethod @abstractmethod
def set_action_space(self): def set_action_space(self):
""" """

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Type, Union
from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam from stable_baselines3.common.logger import HParam
from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
@@ -17,7 +17,7 @@ class TensorboardCallback(BaseCallback):
super().__init__(verbose) super().__init__(verbose)
self.model: Any = None self.model: Any = None
self.logger = None # type: Any self.logger = None # type: Any
self.training_env: Type[BaseEnvironment] = None # type: ignore self.training_env: Union[BaseEnvironment, SubprocVecEnv, VecMonitor] = None
self.actions: Type[Enum] = actions self.actions: Type[Enum] = actions
def _on_training_start(self) -> None: def _on_training_start(self) -> None:
@@ -45,10 +45,7 @@ class TensorboardCallback(BaseCallback):
def _on_step(self) -> bool: def _on_step(self) -> bool:
local_info = self.locals["infos"][0] local_info = self.locals["infos"][0]
if isinstance(self.training_env, SubprocVecEnv):
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
else:
tensorboard_metrics = self.training_env.tensorboard_metrics
for metric in local_info: for metric in local_info:
if metric not in ["episode", "terminal_observation"]: if metric not in ["episode", "terminal_observation"]: