This commit is contained in:
robcaulk
2023-04-21 22:52:19 +02:00
parent f30fc29da0
commit 0a05099713
3 changed files with 13 additions and 9 deletions

View File

@@ -16,13 +16,13 @@ from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
from freqtrade.persistence import Trade
@@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel):
'cpu_count', 1), max(int(self.max_system_threads / 2), 1))
th.set_num_threads(self.max_threads)
self.reward_params = self.freqai_info['rl_config']['model_reward_parameters']
self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env()
self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env()
self.eval_callback: Optional[EvalCallback] = None
self.model_type = self.freqai_info['rl_config']['model_type']
self.rl_config = self.freqai_info['rl_config']
@@ -431,7 +431,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
return 0.
def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int,
seed: int, train_df: DataFrame, price: DataFrame,
env_info: Dict[str, Any] = {}) -> Callable:
"""