diff --git a/freqtrade/freqai/RL/Base3ActionRLEnv.py b/freqtrade/freqai/RL/Base3ActionRLEnv.py index c0a7eedaa..538ca3a6a 100644 --- a/freqtrade/freqai/RL/Base3ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base3ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -94,9 +94,12 @@ class Base3ActionRLEnv(BaseEnvironment): observation = self._get_observation() + # user can play with time if they want + truncated = False + self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base4ActionRLEnv.py b/freqtrade/freqai/RL/Base4ActionRLEnv.py index e883136b2..12f10d4fc 100644 --- a/freqtrade/freqai/RL/Base4ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base4ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -96,9 +96,12 @@ class Base4ActionRLEnv(BaseEnvironment): observation = self._get_observation() + # user can play with time if they want + truncated = False + self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py index 816211cc2..35d04f942 100644 --- a/freqtrade/freqai/RL/Base5ActionRLEnv.py +++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from gym import spaces +from gymnasium import spaces from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment, Positions @@ -101,10 +101,12 @@ class Base5ActionRLEnv(BaseEnvironment): ) observation = self._get_observation() + # user can play with time if they want + truncated = False self._update_history(info) - return observation, step_reward, self._done, info + return observation, step_reward, self._done, truncated, info def is_tradesignal(self, action: int) -> bool: """ diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 7ac77361c..08bb93347 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -4,11 +4,11 @@ from abc import abstractmethod from enum import Enum from typing import Optional, Type, Union -import gym +import gymnasium as gym import numpy as np import pandas as pd -from gym import spaces -from gym.utils import seeding +from gymnasium import spaces +from gymnasium.utils import seeding from pandas import DataFrame @@ -127,6 +127,14 @@ class BaseEnvironment(gym.Env): self.history: dict = {} self.trade_history: list = [] + def get_attr(self, attr: str): + """ + Returns the attribute of the environment + :param attr: attribute to return + :return: attribute + """ + return getattr(self, attr) + @abstractmethod def set_action_space(self): """ @@ -203,7 +211,7 @@ class BaseEnvironment(gym.Env): self.close_trade_profit = [] self._total_unrealized_profit = 1 - return self._get_observation() + return self._get_observation(), self.history @abstractmethod def step(self, action: int): diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index e10880f46..e2c0f5fda 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Type, Union -import gym +import gymnasium as gym import numpy as np import numpy.typing as npt import pandas as pd @@ -16,13 +16,13 @@ from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import set_random_seed -from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.exceptions import OperationalException from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv -from freqtrade.freqai.RL.BaseEnvironment import BaseActions, Positions +from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback from freqtrade.persistence import Trade @@ -46,8 +46,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): 'cpu_count', 1), max(int(self.max_system_threads / 2), 1)) th.set_num_threads(self.max_threads) self.reward_params = self.freqai_info['rl_config']['model_reward_parameters'] - self.train_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() - self.eval_env: Union[SubprocVecEnv, Type[gym.Env]] = gym.Env() + self.train_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() + self.eval_env: Union[VecMonitor, SubprocVecEnv, gym.Env] = gym.Env() self.eval_callback: Optional[EvalCallback] = None self.model_type = self.freqai_info['rl_config']['model_type'] self.rl_config = self.freqai_info['rl_config'] @@ -431,9 +431,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): return 0. -def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, +def make_env(MyRLEnv: Type[BaseEnvironment], env_id: str, rank: int, seed: int, train_df: DataFrame, price: DataFrame, - monitor: bool = False, env_info: Dict[str, Any] = {}) -> Callable: """ Utility function for multiprocessed env. @@ -450,8 +449,7 @@ def make_env(MyRLEnv: Type[gym.Env], env_id: str, rank: int, env = MyRLEnv(df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info) - if monitor: - env = Monitor(env) + return env set_random_seed(seed) return _init diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/RL/TensorboardCallback.py index 7f8c76956..61652c9c6 100644 --- a/freqtrade/freqai/RL/TensorboardCallback.py +++ b/freqtrade/freqai/RL/TensorboardCallback.py @@ -3,8 +3,9 @@ from typing import Any, Dict, Type, Union from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam +from stable_baselines3.common.vec_env import VecEnv -from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment +from freqtrade.freqai.RL.BaseEnvironment import BaseActions class TensorboardCallback(BaseCallback): @@ -12,11 +13,13 @@ class TensorboardCallback(BaseCallback): Custom callback for plotting additional values in tensorboard and episodic summary reports. """ + # Override training_env type to fix type errors + training_env: Union[VecEnv, None] = None + def __init__(self, verbose=1, actions: Type[Enum] = BaseActions): super().__init__(verbose) self.model: Any = None - self.logger = None # type: Any - self.training_env: BaseEnvironment = None # type: ignore + self.logger: Any = None self.actions: Type[Enum] = actions def _on_training_start(self) -> None: @@ -44,6 +47,8 @@ class TensorboardCallback(BaseCallback): def _on_step(self) -> bool: local_info = self.locals["infos"][0] + if self.training_env is None: + return True tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0] for metric in local_info: diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 039b6a175..3580963d4 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -242,8 +242,8 @@ class IFreqaiModel(ABC): new_trained_timerange, pair, strategy, dk, data_load_timerange ) except Exception as msg: - logger.warning(f"Training {pair} raised exception {msg.__class__.__name__}. " - f"Message: {msg}, skipping.") + logger.exception(f"Training {pair} raised exception {msg.__class__.__name__}. " + f"Message: {msg}, skipping.") self.train_timer('stop', pair) diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index 65990da87..a5c2e12b5 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -1,11 +1,12 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Type import torch as th from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions +from freqtrade.freqai.RL.BaseEnvironment import BaseEnvironment from freqtrade.freqai.RL.BaseReinforcementLearningModel import BaseReinforcementLearningModel @@ -84,7 +85,9 @@ class ReinforcementLearner(BaseReinforcementLearningModel): return model - class MyRLEnv(Base5ActionRLEnv): + MyRLEnv: Type[BaseEnvironment] + + class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef] """ User can override any function in BaseRLEnv and gym.Env. Here the user sets a custom reward based on profit and trade duration. diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index b3b8c40e6..73f617027 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -3,7 +3,7 @@ from typing import Any, Dict from pandas import DataFrame from stable_baselines3.common.callbacks import EvalCallback -from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner @@ -41,22 +41,25 @@ class ReinforcementLearner_multiproc(ReinforcementLearner): env_info = self.pack_env_dict(dk.pair) + eval_freq = len(train_df) // self.max_threads + env_id = "train_env" - self.train_env = SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, - train_df, prices_train, - monitor=True, - env_info=env_info) for i - in range(self.max_threads)]) + self.train_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, env_id, i, 1, + train_df, prices_train, + env_info=env_info) for i + in range(self.max_threads)])) eval_env_id = 'eval_env' - self.eval_env = SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, - test_df, prices_test, - monitor=True, - env_info=env_info) for i - in range(self.max_threads)]) + self.eval_env = VecMonitor(SubprocVecEnv([make_env(self.MyRLEnv, eval_env_id, i, 1, + test_df, prices_test, + env_info=env_info) for i + in range(self.max_threads)])) + self.eval_callback = EvalCallback(self.eval_env, deterministic=True, - render=False, eval_freq=len(train_df), + render=False, eval_freq=eval_freq, best_model_save_path=str(dk.data_path)) + # TENSORBOARD CALLBACK DOES NOT RECOMMENDED TO USE WITH MULTIPLE ENVS, + # IT WILL RETURN FALSE INFORMATIONS, NEVERTHLESS NOT THREAD SAFE WITH SB3!!! actions = self.train_env.env_method("get_actions")[0] self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions) diff --git a/requirements-freqai-rl.txt b/requirements-freqai-rl.txt index f4e1e557b..45ccc40cc 100644 --- a/requirements-freqai-rl.txt +++ b/requirements-freqai-rl.txt @@ -3,10 +3,11 @@ # Required for freqai-rl torch==1.13.1; python_version < '3.11' -stable-baselines3==1.7.0; python_version < '3.11' -sb3-contrib==1.7.0; python_version < '3.11' +#until these branches will be released we can use this +gymnasium==0.28.1 +stable_baselines3==2.0.0a5 +sb3_contrib>=2.0.0a4 # Gym is forced to this version by stable-baselines3. setuptools==65.5.1 # Should be removed when gym is fixed. -gym==0.21; python_version < '3.11' # Progress bar for stable-baselines3 and sb3-contrib tqdm==4.65.0; python_version < '3.11'