mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
fix typing in TensorboardCallback
This commit is contained in:
@@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.history: dict = {}
|
self.history: dict = {}
|
||||||
self.trade_history: list = []
|
self.trade_history: list = []
|
||||||
|
|
||||||
|
def get_attr(self, attr: str):
|
||||||
|
"""
|
||||||
|
Returns the attribute of the environment
|
||||||
|
:param attr: attribute to return
|
||||||
|
:return: attribute
|
||||||
|
"""
|
||||||
|
return getattr(self, attr)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_action_space(self):
|
def set_action_space(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Type, Union
|
|||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.logger import HParam
|
from stable_baselines3.common.logger import HParam
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ class TensorboardCallback(BaseCallback):
|
|||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.model: Any = None
|
self.model: Any = None
|
||||||
self.logger = None # type: Any
|
self.logger = None # type: Any
|
||||||
self.training_env: Type[BaseEnvironment] = None # type: ignore
|
self.training_env: Union[BaseEnvironment, SubprocVecEnv, VecMonitor] = None
|
||||||
self.actions: Type[Enum] = actions
|
self.actions: Type[Enum] = actions
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
@@ -45,10 +45,7 @@ class TensorboardCallback(BaseCallback):
|
|||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
|
||||||
local_info = self.locals["infos"][0]
|
local_info = self.locals["infos"][0]
|
||||||
if isinstance(self.training_env, SubprocVecEnv):
|
|
||||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||||
else:
|
|
||||||
tensorboard_metrics = self.training_env.tensorboard_metrics
|
|
||||||
|
|
||||||
for metric in local_info:
|
for metric in local_info:
|
||||||
if metric not in ["episode", "terminal_observation"]:
|
if metric not in ["episode", "terminal_observation"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user