diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py index 17d82a3ba..ef1c02a3b 100644 --- a/freqtrade/freqai/RL/BaseEnvironment.py +++ b/freqtrade/freqai/RL/BaseEnvironment.py @@ -45,7 +45,7 @@ class BaseEnvironment(gym.Env): def __init__(self, df: DataFrame = DataFrame(), prices: DataFrame = DataFrame(), reward_kwargs: dict = {}, window_size=10, starting_point=True, id: str = 'baseenv-1', seed: int = 1, config: dict = {}, live: bool = False, - fee: float = 0.0015): + fee: float = 0.0015, can_short: bool = False): """ Initializes the training/eval environment. :param df: dataframe of features @@ -58,6 +58,7 @@ class BaseEnvironment(gym.Env): :param config: Typical user configuration file :param live: Whether or not this environment is active in dry/live/backtesting :param fee: The fee to use for environmental interactions. + :param can_short: Whether or not the environment can short """ self.config = config self.rl_config = config['freqai']['rl_config'] @@ -73,6 +74,7 @@ class BaseEnvironment(gym.Env): # set here to default 5Ac, but all children envs can override this self.actions: Type[Enum] = BaseActions self.tensorboard_metrics: dict = {} + self.can_short = can_short self.live = live if not self.live and self.add_state_info: self.add_state_info = False diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index d7e3a3cad..af0726c0b 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -165,7 +165,8 @@ class BaseReinforcementLearningModel(IFreqaiModel): env_info = {"window_size": self.CONV_WIDTH, "reward_kwargs": self.reward_params, "config": self.config, - "live": self.live} + "live": self.live, + "can_short": self.can_short} if self.data_provider: env_info["fee"] = self.data_provider._exchange \ .get_fee(symbol=self.data_provider.current_whitelist()[0]) # type: ignore diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 34780f930..bbae7453f 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -133,6 +133,7 @@ class IFreqaiModel(ABC): self.live = strategy.dp.runmode in (RunMode.DRY_RUN, RunMode.LIVE) self.dd.set_pair_dict_info(metadata) self.data_provider = strategy.dp + self.can_short = strategy.can_short if self.live: self.inference_timer('start')