From d88a0dbf82bd180e66b53cca2bc0781179de42a9 Mon Sep 17 00:00:00 2001
From: robcaulk
Date: Sun, 21 Aug 2022 19:58:36 +0200
Subject: [PATCH] add sb3_contrib models to the available agents. include
 sb3_contrib in requirements.

---
 freqtrade/freqai/RL/Base5ActionRLEnv.py         |  2 --
 .../RL/BaseReinforcementLearningModel.py        | 54 +++++++++++--------
 requirements-freqai.txt                         |  4 +-
 3 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/freqtrade/freqai/RL/Base5ActionRLEnv.py b/freqtrade/freqai/RL/Base5ActionRLEnv.py
index b2aeef73b..94de259a9 100644
--- a/freqtrade/freqai/RL/Base5ActionRLEnv.py
+++ b/freqtrade/freqai/RL/Base5ActionRLEnv.py
@@ -223,12 +223,10 @@ class Base5ActionRLEnv(gym.Env):
                     (action == Actions.Neutral.value and self._position == Positions.Long) or
                     (action == Actions.Short_enter.value and self._position == Positions.Short) or
                     (action == Actions.Short_enter.value and self._position == Positions.Long) or
-                    # (action == Actions.Short_exit.value and self._position == Positions.Short) or
                     (action == Actions.Short_exit.value and self._position == Positions.Long) or
                     (action == Actions.Short_exit.value and self._position == Positions.Neutral) or
                     (action == Actions.Long_enter.value and self._position == Positions.Long) or
                     (action == Actions.Long_enter.value and self._position == Positions.Short) or
-                    # (action == Actions.Long_exit.value and self._position == Positions.Long) or
                     (action == Actions.Long_exit.value and self._position == Positions.Short) or
                     (action == Actions.Long_exit.value and self._position == Positions.Neutral))
diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
index a0d5425d3..bb858f3cf 100644
--- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
+++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py
@@ -6,6 +6,7 @@ import numpy.typing as npt
 import pandas as pd
 from pandas import DataFrame
 from abc import abstractmethod
+from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.freqai_interface import IFreqaiModel
 from freqtrade.freqai.RL.Base5ActionRLEnv import Base5ActionRLEnv, Actions, Positions
@@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)
 
 torch.multiprocessing.set_sharing_strategy('file_system')
 
+SB3_MODELS = ['PPO', 'A2C', 'DQN', 'TD3', 'SAC']
+SB3_CONTRIB_MODELS = ['TRPO', 'ARS']
+
 
 class BaseReinforcementLearningModel(IFreqaiModel):
     """
@@ -34,9 +38,19 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         self.train_env: Base5ActionRLEnv = None
         self.eval_env: Base5ActionRLEnv = None
         self.eval_callback: EvalCallback = None
-        mod = __import__('stable_baselines3', fromlist=[
-            self.freqai_info['rl_config']['model_type']])
-        self.MODELCLASS = getattr(mod, self.freqai_info['rl_config']['model_type'])
+        self.model_type = self.freqai_info['rl_config']['model_type']
+        if self.model_type in SB3_MODELS:
+            import_str = 'stable_baselines3'
+        elif self.model_type in SB3_CONTRIB_MODELS:
+            import_str = 'sb3_contrib'
+        else:
+            raise OperationalException(f'{self.model_type} not available in stable_baselines3 or '
+                                       f'sb3_contrib. Please choose one of {SB3_MODELS} or '
+                                       f'{SB3_CONTRIB_MODELS}')
+
+        mod = __import__(import_str, fromlist=[
+            self.model_type])
+        self.MODELCLASS = getattr(mod, self.model_type)
         self.policy_type = self.freqai_info['rl_config']['policy_type']
 
     def train(
@@ -137,7 +151,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
         current_profit = current_value / openrate - 1
 
         total_profit = 0
-        closed_trades = Trade.get_trades_proxy(pair = pair, is_open=False)
+        closed_trades = Trade.get_trades_proxy(pair=pair, is_open=False)
         for trade in closed_trades:
             total_profit += trade.close_profit
 
@@ -223,6 +237,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
     def fit(self, data_dictionary: Dict[str, Any], pair: str = '') -> Any:
         return
 
+
 def make_env(env_id: str, rank: int, seed: int, train_df, price,
              reward_params, window_size, monitor=False) -> Callable:
     """
@@ -244,6 +259,7 @@ def make_env(env_id: str, rank: int, seed: int, train_df, price,
     set_random_seed(seed)
     return _init
 
+
 class MyRLEnv(Base5ActionRLEnv):
     """
     User can override any function in BaseRLEnv and gym.Env. Here the user
@@ -257,26 +273,20 @@ class MyRLEnv(Base5ActionRLEnv):
 
         # close long
         if action == Actions.Long_exit.value and self._position == Positions.Long:
-            last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(current_price) - np.log(last_trade_price))
-
-        if action == Actions.Long_exit.value and self._position == Positions.Long:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_buy_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_sell_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(current_price) - np.log(last_trade_price)) * 2)
+            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
+            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
+            factor = 1
+            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+                factor = 2
+            return float((np.log(current_price) - np.log(last_trade_price)) * factor)
 
         # close short
         if action == Actions.Short_exit.value and self._position == Positions.Short:
-            last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-            current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-            return float(np.log(last_trade_price) - np.log(current_price))
-
-        if action == Actions.Short_exit.value and self._position == Positions.Short:
-            if self.close_trade_profit[-1] > self.profit_aim * self.rr:
-                last_trade_price = self.add_sell_fee(self.prices.iloc[self._last_trade_tick].open)
-                current_price = self.add_buy_fee(self.prices.iloc[self._current_tick].open)
-                return float((np.log(last_trade_price) - np.log(current_price)) * 2)
+            last_trade_price = self.add_entry_fee(self.prices.iloc[self._last_trade_tick].open)
+            current_price = self.add_exit_fee(self.prices.iloc[self._current_tick].open)
+            factor = 1
+            if self.close_trade_profit and self.close_trade_profit[-1] > self.profit_aim * self.rr:
+                factor = 2
+            return float((np.log(last_trade_price) - np.log(current_price)) * factor)
 
         return 0.
diff --git a/requirements-freqai.txt b/requirements-freqai.txt
index 6000f8e0f..de1b6670a 100644
--- a/requirements-freqai.txt
+++ b/requirements-freqai.txt
@@ -9,4 +9,6 @@ lightgbm==3.3.2
 torch==1.12.1
 stable-baselines3==1.6.0
 gym==0.21.0
-tensorboard==2.9.1
\ No newline at end of file
+tensorboard==2.9.1
+optuna==2.10.1
+sb3-contrib==1.6.0
\ No newline at end of file
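
Note: the constructor change above dispatches the configured rl_config model_type to either stable_baselines3 or sb3_contrib. The following is a minimal standalone sketch of that resolution logic, assuming only that both packages expose their agent classes at the top level; the helper name resolve_model_class is hypothetical, and ValueError stands in for freqtrade's OperationalException to keep the snippet self-contained.

# Sketch of the model-class resolution introduced by this patch.
SB3_MODELS = ['PPO', 'A2C', 'DQN', 'TD3', 'SAC']
SB3_CONTRIB_MODELS = ['TRPO', 'ARS']


def resolve_model_class(model_type: str):
    """Return the agent class from stable_baselines3 or sb3_contrib."""
    if model_type in SB3_MODELS:
        import_str = 'stable_baselines3'
    elif model_type in SB3_CONTRIB_MODELS:
        import_str = 'sb3_contrib'
    else:
        raise ValueError(f'{model_type} not available in stable_baselines3 or sb3_contrib')
    mod = __import__(import_str, fromlist=[model_type])
    return getattr(mod, model_type)

# e.g. resolve_model_class('TRPO') returns sb3_contrib.TRPO when sb3-contrib is installed.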
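
The reworked calculate_reward above collapses the duplicated exit branches into a single path: the reward is the log return of the closed trade, doubled when the last recorded close_trade_profit beats profit_aim * rr. Below is a hedged, self-contained sketch of that shaping; the function name exit_reward is hypothetical and the entry/exit fee adjustments (add_entry_fee / add_exit_fee) are omitted for brevity.

import numpy as np


def exit_reward(entry_price: float, exit_price: float, is_short: bool,
                last_close_profit: float, profit_aim: float, rr: float) -> float:
    """Log-return reward for closing a trade, doubled when the profit target was beaten."""
    log_return = np.log(exit_price) - np.log(entry_price)
    if is_short:
        log_return = -log_return  # shorts profit when price falls
    factor = 2 if last_close_profit > profit_aim * rr else 1
    return float(log_return * factor)

# e.g. a long entered at 100 and exited at 110 yields about 0.0953, or 0.1906 with the bonus factor.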