Review notes (free text ahead of the first "diff --git" header):
- _is_trade: the Short_sell / Long_sell clauses compared `action` with the
  enum member itself, reached as Actions.Neutral.<member>. Every other clause
  compares with `.value`, and Actions is a plain Enum, so those two clauses
  disagreed with the rest of the expression. They now use
  Actions.Short_sell.value and Actions.Long_sell.value. One comment line was
  added, so the third hunk's new-side length is 34 instead of 33, and the
  post-image hash on the index line no longer describes the edited "+" lines.
- NOTE(review): is_hold (context only, not modified here) still reads
  Actions.Short, which the first hunk removes - needs a follow-up change.
- NOTE(review): the Short_sell/Long pairing and Long_sell/Short pairing in
  _is_trade are kept as submitted; confirm they are intended, since
  is_tradesignal treats those same combinations as "no trade signal".
- Line breaks, blank context lines and indentation are reconstructed from the
  hunk line counts. "-"/"+" pairs that look identical presumably differ only
  in whitespace that could not be recovered.

diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
index 8f5fe4e03..5ec917719 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearningTDQN.py
@@ -72,9 +72,11 @@ class ReinforcementLearningTDQN(BaseReinforcementLearningModel):
 
 
 class Actions(Enum):
-    Short = 0
-    Long = 1
-    Neutral = 2
+    Neutral = 0
+    Long_buy = 1
+    Long_sell = 2
+    Short_buy = 3
+    Short_sell = 4
 
 
 class Positions(Enum):
@@ -179,31 +181,36 @@ class MyRLEnv(BaseRLEnv):
         self.total_reward += step_reward
 
         trade_type = None
-        if self.is_tradesignal(action): # exclude 3 case not trade
+        if self.is_tradesignal(action): # exclude 3 case not trade
             # Update position
             """
-            Action: Neutral, position: Long -> Close Long
-            Action: Neutral, position: Short -> Close Short
-
-            Action: Long, position: Neutral -> Open Long
+            Action: Neutral, position: Long -> Close Long
+            Action: Neutral, position: Short -> Close Short
+
+            Action: Long, position: Neutral -> Open Long
             Action: Long, position: Short -> Close Short and Open Long
-
-            Action: Short, position: Neutral -> Open Short
+
+            Action: Short, position: Neutral -> Open Short
             Action: Short, position: Long -> Close Long and Open Short
             """
-
-            temp_position = self._position
+
             if action == Actions.Neutral.value:
                 self._position = Positions.Neutral
                 trade_type = "neutral"
-            elif action == Actions.Long.value:
+            elif action == Actions.Long_buy.value:
                 self._position = Positions.Long
                 trade_type = "long"
-            elif action == Actions.Short.value:
+            elif action == Actions.Short_buy.value:
                 self._position = Positions.Short
                 trade_type = "short"
+            elif action == Actions.Long_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
+            elif action == Actions.Short_sell.value:
+                self._position = Positions.Neutral
+                trade_type = "neutral"
             else:
-                print("case not define")
+                print("case not defined")
 
             # Update last trade tick
             self._last_trade_tick = self._current_tick
@@ -250,23 +257,34 @@ class MyRLEnv(BaseRLEnv):
         return 0.
 
     def is_tradesignal(self, action):
-        # trade signal
+        # trade signal
         """
         not trade signal is :
-        Action: Neutral, position: Neutral -> Nothing
+        Action: Neutral, position: Neutral -> Nothing
         Action: Long, position: Long -> Hold Long
         Action: Short, position: Short -> Hold Short
         """
-        return not ((action == Actions.Neutral.value and self._position == Positions.Neutral)
-                    or (action == Actions.Short.value and self._position == Positions.Short)
-                    or (action == Actions.Long.value and self._position == Positions.Long))
+        return not ((action == Actions.Neutral.value and self._position == Positions.Neutral) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Short) or
+                    (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Short_sell.value and self._position == Positions.Long) or
 
-    def _is_trade(self, action: Actions):
-        return ((action == Actions.Long.value and self._position == Positions.Short) or
-                (action == Actions.Short.value and self._position == Positions.Long) or
-                (action == Actions.Neutral.value and self._position == Positions.Long) or
-                (action == Actions.Neutral.value and self._position == Positions.Short)
-                )
+                    (action == Actions.Long_buy.value and self._position == Positions.Long) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Long) or
+                    (action == Actions.Long_buy.value and self._position == Positions.Short) or
+                    (action == Actions.Long_sell.value and self._position == Positions.Short))
+
+    def _is_trade(self, action):
+        return ((action == Actions.Long_buy.value and self._position == Positions.Short) or
+                (action == Actions.Short_buy.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Long) or
+                (action == Actions.Neutral.value and self._position == Positions.Short) or
+
+                # compare against .value, like every other clause in this expression
+                (action == Actions.Short_sell.value and self._position == Positions.Long) or
+                (action == Actions.Long_sell.value and self._position == Positions.Short)
+                )
 
     def is_hold(self, action):
         return ((action == Actions.Short.value and self._position == Positions.Short)