mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-19 06:11:15 +00:00
add ability to integrate state info or not, and prevent state info integration during backtesting
This commit is contained in:
@@ -35,6 +35,7 @@ class BaseEnvironment(gym.Env):
|
||||
id: str = 'baseenv-1', seed: int = 1, config: dict = {}):
|
||||
|
||||
self.rl_config = config['freqai']['rl_config']
|
||||
self.add_state_info = self.rl_config.get('add_state_info', False)
|
||||
self.id = id
|
||||
self.seed(seed)
|
||||
self.reset_env(df, prices, window_size, reward_kwargs, starting_point)
|
||||
@@ -58,7 +59,11 @@ class BaseEnvironment(gym.Env):
|
||||
self.fee = 0.0015
|
||||
|
||||
# # spaces
|
||||
self.shape = (window_size, self.signal_features.shape[1] + 3)
|
||||
if self.add_state_info:
|
||||
self.total_features = self.signal_features.shape[1] + 3
|
||||
else:
|
||||
self.total_features = self.signal_features.shape[1]
|
||||
self.shape = (window_size, self.total_features)
|
||||
self.set_action_space()
|
||||
self.observation_space = spaces.Box(
|
||||
low=-1, high=1, shape=self.shape, dtype=np.float32)
|
||||
@@ -126,15 +131,20 @@ class BaseEnvironment(gym.Env):
|
||||
"""
|
||||
features_window = self.signal_features[(
|
||||
self._current_tick - self.window_size):self._current_tick]
|
||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
||||
columns=['current_profit_pct', 'position', 'trade_duration'],
|
||||
index=features_window.index)
|
||||
if self.add_state_info:
|
||||
features_and_state = DataFrame(np.zeros((len(features_window), 3)),
|
||||
columns=['current_profit_pct',
|
||||
'position',
|
||||
'trade_duration'],
|
||||
index=features_window.index)
|
||||
|
||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||
features_and_state['position'] = self._position.value
|
||||
features_and_state['trade_duration'] = self.get_trade_duration()
|
||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||
return features_and_state
|
||||
features_and_state['current_profit_pct'] = self.get_unrealized_profit()
|
||||
features_and_state['position'] = self._position.value
|
||||
features_and_state['trade_duration'] = self.get_trade_duration()
|
||||
features_and_state = pd.concat([features_window, features_and_state], axis=1)
|
||||
return features_and_state
|
||||
else:
|
||||
return features_window
|
||||
|
||||
def get_trade_duration(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user