mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-19 06:11:15 +00:00
add tensorboard category
This commit is contained in:
@@ -137,7 +137,8 @@ class BaseEnvironment(gym.Env):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def tensorboard_log(self, metric: str, value: Union[int, float] = 1, inc: bool = True):
|
||||
def tensorboard_log(self, metric: str, value: Optional[Union[int, float]] = None,
|
||||
category: str = "custom"):
|
||||
"""
|
||||
Function builds the tensorboard_metrics dictionary
|
||||
to be parsed by the TensorboardCallback. This
|
||||
@@ -149,17 +150,23 @@ class BaseEnvironment(gym.Env):
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
if not self._is_valid(action):
|
||||
self.tensorboard_log("is_valid")
|
||||
self.tensorboard_log("invalid")
|
||||
return -2
|
||||
|
||||
:param metric: metric to be tracked and incremented
|
||||
:param value: value to increment `metric` by
|
||||
:param inc: sets whether the `value` is incremented or not
|
||||
:param value: `metric` value
|
||||
:param category: `metric` category
|
||||
"""
|
||||
if not inc or metric not in self.tensorboard_metrics:
|
||||
self.tensorboard_metrics[metric] = value
|
||||
increment = True if not value else False
|
||||
value = 1 if increment else value
|
||||
|
||||
if category not in self.tensorboard_metrics:
|
||||
self.tensorboard_metrics[category] = {}
|
||||
|
||||
if not increment or metric not in self.tensorboard_metrics[category]:
|
||||
self.tensorboard_metrics[category][metric] = value
|
||||
else:
|
||||
self.tensorboard_metrics[metric] += value
|
||||
self.tensorboard_metrics[category][metric] += value
|
||||
|
||||
def reset_tensorboard_log(self):
|
||||
self.tensorboard_metrics = {}
|
||||
|
||||
Reference in New Issue
Block a user