From 646dd63faf89cbad809905349c56c31678435765 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sun, 15 Oct 2023 10:41:07 +0200 Subject: [PATCH] Properly close out progressbarCallback based on suggestions provided in https://github.com/DLR-RM/stable-baselines3/issues/1645 --- .../prediction_models/ReinforcementLearner.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py index d9a11a7a8..d4c1881a6 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Dict, Type import torch as th +from stable_baselines3.common.callbacks import ProgressBarCallback from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions @@ -73,12 +74,19 @@ class ReinforcementLearner(BaseReinforcementLearningModel): 'trained agent.') model = self.dd.model_dictionary[dk.pair] model.set_env(self.train_env) + callbacks = [self.eval_callback, self.tensorboard_callback] + use_progressbar = self.rl_config.get('progress_bar', False) + if use_progressbar: + callbacks.insert(0, ProgressBarCallback()) - model.learn( - total_timesteps=int(total_timesteps), - callback=[self.eval_callback, self.tensorboard_callback], - progress_bar=self.rl_config.get('progress_bar', False) - ) + try: + model.learn( + total_timesteps=int(total_timesteps), + callback=callbacks, + ) + finally: + if use_progressbar: + callbacks[0].on_training_end() if Path(dk.data_path / "best_model.zip").is_file(): logger.info('Callback found a best model.')