mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-02-09 07:40:40 +00:00
Properly close out progressbarCallback
based on suggestions provided in https://github.com/DLR-RM/stable-baselines3/issues/1645
This commit is contained in:
@@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
import torch as th
|
||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||
@@ -73,12 +74,19 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
'trained agent.')
|
||||
model = self.dd.model_dictionary[dk.pair]
|
||||
model.set_env(self.train_env)
|
||||
callbacks = [self.eval_callback, self.tensorboard_callback]
|
||||
use_progressbar = self.rl_config.get('progress_bar', False)
|
||||
if use_progressbar:
|
||||
callbacks.insert(0, ProgressBarCallback())
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=[self.eval_callback, self.tensorboard_callback],
|
||||
progress_bar=self.rl_config.get('progress_bar', False)
|
||||
)
|
||||
try:
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=callbacks,
|
||||
)
|
||||
finally:
|
||||
if use_progressbar:
|
||||
callbacks[0].on_training_end()
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
logger.info('Callback found a best model.')
|
||||
|
||||
Reference in New Issue
Block a user