diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner.py b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
index d4c1881a6..fbf12008a 100644
--- a/freqtrade/freqai/prediction_models/ReinforcementLearner.py
+++ b/freqtrade/freqai/prediction_models/ReinforcementLearner.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Any, Dict, Type
+from typing import Any, Dict, List, Optional, Type
 
 import torch as th
 from stable_baselines3.common.callbacks import ProgressBarCallback
@@ -74,10 +74,11 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                         'trained agent.')
             model = self.dd.model_dictionary[dk.pair]
             model.set_env(self.train_env)
-        callbacks = [self.eval_callback, self.tensorboard_callback]
-        use_progressbar = self.rl_config.get('progress_bar', False)
-        if use_progressbar:
-            callbacks.insert(0, ProgressBarCallback())
+        callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
+        progressbar_callback: Optional[ProgressBarCallback] = None
+        if self.rl_config.get('progress_bar', False):
+            progressbar_callback = ProgressBarCallback()
+            callbacks.insert(0, progressbar_callback)
 
         try:
             model.learn(
@@ -85,8 +86,8 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
                 callback=callbacks,
             )
         finally:
-            if use_progressbar:
-                callbacks[0].on_training_end()
+            if progressbar_callback:
+                progressbar_callback.on_training_end()
 
         if Path(dk.data_path / "best_model.zip").is_file():
             logger.info('Callback found a best model.')
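
Note on the change above: the progress-bar callback is now held in a typed Optional[ProgressBarCallback] local instead of being re-derived from a config flag and reached via callbacks[0], so the finally block always closes the right callback even if the list order changes or model.learn() raises (e.g. on KeyboardInterrupt). Below is a minimal, self-contained sketch of the same pattern; the function name and its parameters (model, eval_callback, tensorboard_callback, rl_config, total_timesteps) are hypothetical stand-ins for the self.* attributes used in fit(), not freqtrade API.

    from typing import Any, Dict, List, Optional

    from stable_baselines3.common.callbacks import BaseCallback, ProgressBarCallback


    def train_with_optional_progress_bar(
        model: Any,
        total_timesteps: int,
        eval_callback: BaseCallback,
        tensorboard_callback: BaseCallback,
        rl_config: Dict[str, Any],
    ) -> None:
        # Same structure as the patched fit(): build the callback list first.
        callbacks: List[Any] = [eval_callback, tensorboard_callback]

        # Keep a handle on the progress bar rather than re-reading the config
        # flag, so the cleanup below cannot target the wrong entry in `callbacks`.
        progressbar_callback: Optional[ProgressBarCallback] = None
        if rl_config.get("progress_bar", False):
            progressbar_callback = ProgressBarCallback()
            callbacks.insert(0, progressbar_callback)

        try:
            model.learn(total_timesteps=int(total_timesteps), callback=callbacks)
        finally:
            # BaseCallback.on_training_end() lets the progress bar close its
            # tqdm display even when learn() exits early.
            if progressbar_callback:
                progressbar_callback.on_training_end()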