From 8cf0e4a316747a3c9456c9aa352d71dee1481f47 Mon Sep 17 00:00:00 2001
From: Matthias
Date: Wed, 26 Apr 2023 19:43:42 +0200
Subject: [PATCH] Fix mypy typing errors

---
 freqtrade/freqai/RL/TensorboardCallback.py            | 10 ++++++++--
 .../freqai/prediction_models/ReinforcementLearner.py  |  7 +++++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py
index 784dc848d..61652c9c6 100644
--- a/freqtrade/freqai/RL/TensorboardCallback.py
+++ b/freqtrade/freqai/RL/TensorboardCallback.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union
 
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.logger import HParam
+from stable_baselines3.common.vec_env import VecEnv
 
 from freqtrade.freqai.RL.BaseEnvironment import BaseActions
 
@@ -12,10 +13,13 @@ class TensorboardCallback(BaseCallback):
     Custom callback for plotting additional values in tensorboard and
     episodic summary reports.
     """
+    # Override training_env type to fix type errors
+    training_env: Union[VecEnv, None] = None
+
     def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
         super().__init__(verbose)
         self.model: Any = None
-        self.logger = None  # type: Any
+        self.logger: Any = None
         self.actions: Type[Enum] = actions
 
     def _on_training_start(self) -> None:
@@ -43,7 +47,9 @@ class TensorboardCallback(BaseCallback):
     def _on_step(self) -> bool:
 
         local_info = self.locals["infos"][0]
-        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]  # type: ignore
+        if self.training_env is None:
+            return True
+        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
 
         for metric in local_info:
             if metric not in ["episode", "terminal_observation"]:
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index 65990da87..a5c2e12b5 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -1,11 +1,12 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Type
 
 import torch as th
 
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
+from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
 from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
 
 
@@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
 
         return model
 
-    class MyRLEnv(Base5ActionRLEnv):
+    MyRLEnv: Type[BaseEnvironment]
+
+    class MyRLEnv(Base5ActionRLEnv):  # type: ignore[no-redef]
         """
         User can override any function in BaseRLEnv and gym.Env. Here the user
         sets a custom reward based on profit and trade duration.
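
Illustrative note (not part of the patch): the two changes boil down to a pair of common mypy patterns -- an attribute typed as "may be None" that callers must guard before use, and a Type[...] declaration paired with a "# type: ignore[no-redef]" on the nested class that redefines it. The sketch below demonstrates both patterns with made-up stand-in classes (StubVecEnv, DemoCallback, DemoModel); it does not use the real stable_baselines3 or freqtrade APIs.

from typing import Any, Dict, List, Optional, Type


class StubVecEnv:
    """Stand-in for stable_baselines3's VecEnv, used only for this sketch."""

    def get_attr(self, attr_name: str) -> List[Dict[str, Any]]:
        # A real VecEnv returns one value per sub-environment.
        return [{"total_profit": 0.0}]


class DemoCallback:
    # Pattern 1: type the attribute as Optional and guard against None before
    # use, instead of silencing mypy with a bare "# type: ignore".
    training_env: Optional[StubVecEnv] = None

    def on_step(self) -> bool:
        if self.training_env is None:
            return True
        metrics = self.training_env.get_attr("tensorboard_metrics")[0]
        return isinstance(metrics, dict)


class BaseDemoEnv:
    """Stand-in for a base environment class."""


class FiveActionDemoEnv(BaseDemoEnv):
    """Stand-in for a concrete environment subclass."""


class DemoModel:
    # Pattern 2: declare the attribute with the broad base type, then redefine
    # it as a nested subclass; the redefinition carries ignore[no-redef].
    MyRLEnv: Type[BaseDemoEnv]

    class MyRLEnv(FiveActionDemoEnv):  # type: ignore[no-redef]
        """Subclasses may swap in their own environment class here."""


if __name__ == "__main__":
    cb = DemoCallback()
    print(cb.on_step())  # True -- training_env is still None, early return
    cb.training_env = StubVecEnv()
    print(cb.on_step())  # True -- metrics dict fetched from the stub env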