Improve formatting

This commit is contained in:
Matthias
2023-04-17 20:27:18 +02:00
parent c055f82e9a
commit 3fb5cd3df6
5 changed files with 18 additions and 20 deletions

View File

@@ -94,12 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment):
observation = self._get_observation()
#user can play with time if they want
# user can play with time if they want
truncated = False
self._update_history(info)
return observation, step_reward, self._done,truncated, info
return observation, step_reward, self._done, truncated, info
def is_tradesignal(self, action: int) -> bool:
"""

View File

@@ -96,12 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment):
observation = self._get_observation()
#user can play with time if they want
# user can play with time if they want
truncated = False
self._update_history(info)
return observation, step_reward, self._done,truncated, info
return observation, step_reward, self._done, truncated, info
def is_tradesignal(self, action: int) -> bool:
"""

View File

@@ -101,12 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment):
)
observation = self._get_observation()
#user can play with time if they want
# user can play with time if they want
truncated = False
self._update_history(info)
return observation, step_reward, self._done,truncated, info
return observation, step_reward, self._done, truncated, info
def is_tradesignal(self, action: int) -> bool:
"""

View File

@@ -449,7 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int,
env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank,
**env_info)
return env
set_random_seed(seed)
return _init

View File

@@ -3,8 +3,8 @@ from typing import Any, Dict
from pandas import DataFrame
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env import VecMonitor
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
@@ -45,23 +45,21 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
env_id = "train_env"
self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1,
train_df, prices_train,
env_info=env_info) for i
in range(self.max_threads)]))
train_df, prices_train,
env_info=env_info) for i
in range(self.max_threads)]))
eval_env_id = 'eval_env'
self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1,
test_df, prices_test,
env_info=env_info) for i
in range(self.max_threads)]))
test_df, prices_test,
env_info=env_info) for i
in range(self.max_threads)]))
self.eval_callback = EvalCallback(self.eval_env, deterministic=True,
render=False, eval_freq=eval_freq,
best_model_save_path=str(dk.data_path))
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
# TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS,
# IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!!
actions = self.train_env.env_method("get_actions")[0]
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)