Merge pull request #9303 from freqtrade/fix/progressbarCallback

Improve freqAI RL (error) behavior
This commit is contained in:
Matthias
2023-10-17 18:01:31 +02:00
committed by GitHub
3 changed files with 24 additions and 9 deletions

View File

@@ -159,7 +159,7 @@ class BaseEnvironment(gym.Env):
function is designed for tracking incremented objects,
events, actions inside the training environment.
For example, a user can call this to track the
frequency of occurence of an `is_valid` call in
frequency of occurrence of an `is_valid` call in
their `calculate_reward()`:
def calculate_reward(self, action: int) -> float:

View File

@@ -1,8 +1,9 @@
import logging
from pathlib import Path
from typing import Any, Dict, Type
from typing import Any, Dict, List, Optional, Type
import torch as th
from stable_baselines3.common.callbacks import ProgressBarCallback
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
@@ -73,19 +74,27 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
'trained agent.')
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
progressbar_callback: Optional[ProgressBarCallback] = None
if self.rl_config.get('progress_bar', False):
progressbar_callback = ProgressBarCallback()
callbacks.insert(0, progressbar_callback)
model.learn(
total_timesteps=int(total_timesteps),
callback=[self.eval_callback, self.tensorboard_callback],
progress_bar=self.rl_config.get('progress_bar', False)
)
try:
model.learn(
total_timesteps=int(total_timesteps),
callback=callbacks,
)
finally:
if progressbar_callback:
progressbar_callback.on_training_end()
if Path(dk.data_path / "best_model.zip").is_file():
logger.info('Callback found a best model.')
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
return best_model
logger.info('Couldnt find best model, using final model instead.')
logger.info("Couldn't find best model, using final model instead.")
return model

View File

@@ -49,7 +49,13 @@ class TensorboardCallback(BaseCallback):
local_info = self.locals["infos"][0]
if self.training_env is None:
return True
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
if hasattr(self.training_env, 'envs'):
tensorboard_metrics = self.training_env.envs[0].unwrapped.tensorboard_metrics
else:
# For RL-multiproc - usage of [0] might need to be evaluated
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
for metric in local_info:
if metric not in ["episode", "terminal_observation"]: