Fix mypy typing errors

This commit is contained in:
Matthias
2023-04-26 19:43:42 +02:00
parent 6d3c94a739
commit 8cf0e4a316
2 changed files with 13 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.vec_env import VecEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions
@@ -12,10 +13,13 @@ class TensorboardCallback(BaseCallback):
Custom callback for plotting additional values in tensorboard and
episodic summary reports.
"""
# Override training_env type to fix type errors
training_env: Union[VecEnv, None] = None
def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
super().__init__(verbose)
self.model: Any = None
self.logger = None # type: Any
self.logger: Any = None
self.actions: Type[Enum] = actions
def _on_training_start(self) -> None:
@@ -43,7 +47,9 @@ class TensorboardCallback(BaseCallback):
def _on_step(self) -> bool:
local_info = self.locals["infos"][0]
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] # type: ignore
if self.training_env is None:
return True
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
for metric in local_info:
if metric not in ["episode", "terminal_observation"]:

View File

@@ -1,11 +1,12 @@
import logging
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, Type
import torch as th
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment
from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel
@@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
return model
class MyRLEnv(Base5ActionRLEnv):
MyRLEnv: Type[BaseEnvironment]
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
"""
User can override any function in BaseRLEnv and gym.Env. Here the user
sets a custom reward based on profit and trade duration.