mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-01-20 14:00:38 +00:00
Merge pull request #9303 from freqtrade/fix/progressbarCallback
Improve freqAI RL (error) behavior
This commit is contained in:
@@ -159,7 +159,7 @@ class BaseEnvironment(gym.Env):
|
||||
function is designed for tracking incremented objects,
|
||||
events, actions inside the training environment.
|
||||
For example, a user can call this to track the
|
||||
frequency of occurence of an `is_valid` call in
|
||||
frequency of occurrence of an `is_valid` call in
|
||||
their `calculate_reward()`:
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch as th
|
||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
|
||||
@@ -73,19 +74,27 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
||||
'trained agent.')
|
||||
model = self.dd.model_dictionary[dk.pair]
|
||||
model.set_env(self.train_env)
|
||||
callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
|
||||
progressbar_callback: Optional[ProgressBarCallback] = None
|
||||
if self.rl_config.get('progress_bar', False):
|
||||
progressbar_callback = ProgressBarCallback()
|
||||
callbacks.insert(0, progressbar_callback)
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=[self.eval_callback, self.tensorboard_callback],
|
||||
progress_bar=self.rl_config.get('progress_bar', False)
|
||||
)
|
||||
try:
|
||||
model.learn(
|
||||
total_timesteps=int(total_timesteps),
|
||||
callback=callbacks,
|
||||
)
|
||||
finally:
|
||||
if progressbar_callback:
|
||||
progressbar_callback.on_training_end()
|
||||
|
||||
if Path(dk.data_path / "best_model.zip").is_file():
|
||||
logger.info('Callback found a best model.')
|
||||
best_model = self.MODELCLASS.load(dk.data_path / "best_model")
|
||||
return best_model
|
||||
|
||||
logger.info('Couldnt find best model, using final model instead.')
|
||||
logger.info("Couldn't find best model, using final model instead.")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -49,7 +49,13 @@ class TensorboardCallback(BaseCallback):
|
||||
local_info = self.locals["infos"][0]
|
||||
if self.training_env is None:
|
||||
return True
|
||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||
|
||||
if hasattr(self.training_env, 'envs'):
|
||||
tensorboard_metrics = self.training_env.envs[0].unwrapped.tensorboard_metrics
|
||||
|
||||
else:
|
||||
# For RL-multiproc - usage of [0] might need to be evaluated
|
||||
tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
|
||||
|
||||
for metric in local_info:
|
||||
if metric not in ["episode", "terminal_observation"]:
|
||||
|
||||
Reference in New Issue
Block a user