diff --git a/freqtrade/freqai/tensorboard/TensorboardCallback.py b/freqtrade/freqai/tensorboard/TensorboardCallback.py index 2be917616..b8a351498 100644 --- a/freqtrade/freqai/tensorboard/TensorboardCallback.py +++ b/freqtrade/freqai/tensorboard/TensorboardCallback.py @@ -3,7 +3,6 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam -from stable_baselines3.common.vec_env import VecEnv from freqtrade.freqai.RL.BaseEnvironment import BaseActions @@ -13,13 +12,9 @@ class TensorboardCallback(BaseCallback): Custom callback for plotting additional values in tensorboard and episodic summary reports. """ - # Override training_env type to fix type errors - training_env: Union[VecEnv, None] = None - def __init__(self, verbose=1, actions: Type[Enum] = BaseActions): super().__init__(verbose) self.model: Any = None - self.logger: Any = None self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -47,8 +42,6 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] - if self.training_env is None: - return True if hasattr(self.training_env, 'envs'): tensorboard_metrics = self.training_env.envs[0].unwrapped.tensorboard_metrics diff --git a/requirements-freqai-rl.txt b/requirements-freqai-rl.txt index c2cca5427..fba25d409 100644 --- a/requirements-freqai-rl.txt +++ b/requirements-freqai-rl.txt @@ -5,7 +5,7 @@ torch==2.0.1 #until these branches will be released we can use this gymnasium==0.29.1 -stable_baselines3==2.1.0 +stable_baselines3==2.2.1 sb3_contrib>=2.0.0a9 # Progress bar for stable-baselines3 and sb3-contrib tqdm==4.66.1