diff --git a/freqtrade/freqai/RL/BaseEnvironment.py b/freqtrade/freqai/RL/BaseEnvironment.py
index 54502c869..b8548dd16 100644
--- a/freqtrade/freqai/RL/BaseEnvironment.py
+++ b/freqtrade/freqai/RL/BaseEnvironment.py
@@ -159,7 +159,7 @@ class BaseEnvironment(gym.Env):
         function is designed for tracking incremented objects,
         events, actions inside the training environment. For
         example, a user can call this to track the
-        frequency of occurence of an `is_valid` call in
+        frequency of occurrence of an `is_valid` call in
         their `calculate_reward()`:

         def calculate_reward(self, action: int) -> float:
diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index a11decc92..fbf12008a 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -1,8 +1,9 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict, Type
+from typing import Any, Dict, List, Optional, Type

 import torch as th
+from stable_baselines3.common.callbacks import ProgressBarCallback

 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
@@ -73,19 +74,27 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                         'trained agent.')
             model = self.dd.model_dictionary[dk.pair]
             model.set_env(self.train_env)
+        callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
+        progressbar_callback: Optional[ProgressBarCallback] = None
+        if self.rl_config.get('progress_bar', False):
+            progressbar_callback = ProgressBarCallback()
+            callbacks.insert(0, progressbar_callback)

-        model.learn(
-            total_timesteps=int(total_timesteps),
-            callback=[self.eval_callback, self.tensorboard_callback],
-            progress_bar=self.rl_config.get('progress_bar', False)
-        )
+        try:
+            model.learn(
+                total_timesteps=int(total_timesteps),
+                callback=callbacks,
+            )
+        finally:
+            if progressbar_callback:
+                progressbar_callback.on_training_end()

         if Path(dk.data_path / "best_model.zip").is_file():
             logger.info('Callback found a best model.')
             best_model = self.MODELCLASS.load(dk.data_path / "best_model")
             return best_model

-        logger.info('Couldnt find best model, using final model instead.')
+        logger.info("Couldn't find best model, using final model instead.")
         return model
diff --git a/freqtrade/freqai/tensorboard/TensorboardCallback.py b/freqtrade/freqai/tensorboard/TensorboardCallback.py
index 61652c9c6..2be917616 100644
--- a/freqtrade/freqai/tensorboard/TensorboardCallback.py
+++ b/freqtrade/freqai/tensorboard/TensorboardCallback.py
@@ -49,7 +49,13 @@ class TensorboardCallback(BaseCallback):
         local_info = self.locals["infos"][0]
         if self.training_env is None:
             return True
-        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
+
+        if hasattr(self.training_env, 'envs'):
+            tensorboard_metrics = self.training_env.envs[0].unwrapped.tensorboard_metrics
+
+        else:
+            # For RL-multiproc - usage of [0] might need to be evaluated
+            tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]

         for metric in local_info:
             if metric not in ["episode", "terminal_observation"]: