mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-18 22:01:15 +00:00
fix mypy
This commit is contained in:
@@ -16,13 +16,13 @@ from pandas import DataFrame
|
|||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.utils import set_random_seed
|
from stable_baselines3.common.utils import set_random_seed
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||||
|
|
||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
|
||||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||||
from freqtrade.persistence import Trade
|
from freqtrade.persistence import Trade
|
||||||
|
|
||||||
@@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
|
||||||
th.set_num_threads(self.max_threads)
|
th.set_num_threads(self.max_threads)
|
||||||
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
|
||||||
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||||
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
|
self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
|
||||||
self.eval_callback: Optional[EvalCallback] = None
|
self.eval_callback: Optional[EvalCallback] = None
|
||||||
self.model_type = self.freqai_info['rl_config']['model_type']
|
self.model_type = self.freqai_info['rl_config']['model_type']
|
||||||
self.rl_config = self.freqai_info['rl_config']
|
self.rl_config = self.freqai_info['rl_config']
|
||||||
@@ -431,7 +431,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
return 0.
|
return 0.
|
||||||
|
|
||||||
|
|
||||||
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
|
def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
|
||||||
seed: int, train_df: DataFrame, price: DataFrame,
|
seed: int, train_df: DataFrame, price: DataFrame,
|
||||||
env_info: Dict[str, Any] = {}) -> Callable:
|
env_info: Dict[str, Any] = {}) -> Callable:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Type, Union
|
|||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.logger import HParam
|
from stable_baselines3.common.logger import HParam
|
||||||
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ class TensorboardCallback(BaseCallback):
|
|||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.model: Any = None
|
self.model: Any = None
|
||||||
self.logger = None # type: Any
|
self.logger = None # type: Any
|
||||||
self.training_env: BaseEnvironment = None # type: ignore
|
self.training_env: Type[BaseEnvironment] = None # type: ignore
|
||||||
self.actions: Type[Enum] = actions
|
self.actions: Type[Enum] = actions
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
@@ -44,7 +45,10 @@ class TensorboardCallback(BaseCallback):
|
|||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
|
||||||
local_info = self.locals["infos"][0]
|
local_info = self.locals["infos"][0]
|
||||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
if isinstance(self.training_env, SubprocVecEnv):
|
||||||
|
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||||
|
else:
|
||||||
|
tensorboard_metrics = self.training_env.tensorboard_metrics
|
||||||
|
|
||||||
for metric in local_info:
|
for metric in local_info:
|
||||||
if metric not in ["episode", "terminal_observation"]:
|
if metric not in ["episode", "terminal_observation"]:
|
||||||
|
|||||||
@@ -242,8 +242,8 @@ class IFreqaiModel(ABC):
|
|||||||
new_trained_timerange, pair, strategy, dk, data_load_timerange
|
new_trained_timerange, pair, strategy, dk, data_load_timerange
|
||||||
)
|
)
|
||||||
except Exception as msg:
|
except Exception as msg:
|
||||||
logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. "
|
logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. "
|
||||||
f"Message: {msg}, skipping.")
|
f"Message: {msg}, skipping.")
|
||||||
|
|
||||||
self.train_timer('stop', pair)
|
self.train_timer('stop', pair)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user