Fix the import logic, fix the tests, and move all tensorboard modules into a single folder

This commit is contained in:
robcaulk
2023-05-12 07:56:44 +00:00
parent 6df5cb8878
commit 692fa390c6
9 changed files with 98 additions and 84 deletions

View File

@@ -23,7 +23,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
from freqtrade.persistence import Trade

View File

@@ -1,12 +1,13 @@
# ensure users can still use a non-torch freqai version
try:
from freqtrade.freqai.tensorboard import TensorBoardCallback, TensorboardLogger
from freqtrade.freqai.tensorboard.tensorboard import TensorBoardCallback, TensorboardLogger
TBLogger = TensorboardLogger
TBCallback = TensorBoardCallback
except ModuleNotFoundError:
from freqtrade.freqai.tensorboard import BaseTensorBoardCallback, BaseTensorboardLogger
TBLogger = BaseTensorboardLogger # type: ignore
TBCallback = BaseTensorBoardCallback # type: ignore
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
BaseTensorboardLogger)
TBLogger = BaseTensorboardLogger
TBCallback = BaseTensorBoardCallback
__all__ = (
"TBLogger",

View File

@@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
logger = logging.getLogger(__name__)

View File

@@ -1,77 +0,0 @@
import logging
from pathlib import Path
from typing import Any
import xgboost as xgb
logger = logging.getLogger(__name__)
class BaseTensorboardLogger:
    """No-op fallback logger used when torch/tensorboard is not installed.

    Warns once at construction and silently discards all scalar logs so
    callers can use the same interface regardless of installation.
    """

    def __init__(self, logdir: str = "tensorboard", id: str = "unique-id"):
        # Space added between the two concatenated sentences; original emitted
        # "written.Use ensure ..." with a typo and no separator.
        logger.warning("Tensorboard is not installed, no logs will be written. "
                       "Ensure torch is installed, or use the torch/RL docker images")

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Discard the scalar; the torch-backed subclass writes it instead."""
        return

    def log_scaler(self, tag: str, scalar_value: Any, step: int):
        """Deprecated alias for :meth:`log_scalar` (original typo'd name)."""
        return self.log_scalar(tag, scalar_value, step)

    def close(self):
        """Nothing to release in the no-op implementation."""
        return
class BaseTensorBoardCallback(xgb.callback.TrainingCallback):
    """No-op xgboost training callback used when tensorboard is unavailable.

    Warns at construction; every callback hook is a pass-through so training
    proceeds normally without writing any logs.
    """

    def __init__(self, logdir: str = "tensorboard", id: str = "unique-id", test_size=1):
        # Fixed default id typo ("uniqu-id") and missing space between the
        # two concatenated message sentences.
        logger.warning("Tensorboard is not installed, no logs will be written. "
                       "Ensure torch is installed, or use the torch/RL docker images")

    def after_iteration(
        self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
    ) -> bool:
        # Returning False tells xgboost to continue training.
        return False

    def after_training(self, model):
        # Must return the model unchanged per the TrainingCallback contract.
        return model
class TensorboardLogger(BaseTensorboardLogger):
    """Torch-backed logger that writes scalars beneath ``<logdir>/tensorboard``."""

    def __init__(self, logdir: Path = Path("tensorboard")):
        # Imported lazily so this module can still be loaded without torch.
        from torch.utils.tensorboard import SummaryWriter
        run_dir = f"{str(logdir)}/tensorboard"
        self.writer: SummaryWriter = SummaryWriter(run_dir)

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Record one scalar value for *tag* at the given *step*."""
        self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Flush any buffered events, then release the writer."""
        self.writer.flush()
        self.writer.close()
class TensorBoardCallback(BaseTensorBoardCallback):
    """Torch-backed xgboost callback that mirrors train/valid loss to tensorboard."""

    def __init__(self, logdir: Path = Path("tensorboard")):
        # Imported lazily so this module can still be loaded without torch.
        from torch.utils.tensorboard import SummaryWriter
        run_dir = f"{str(logdir)}/tensorboard"
        self.writer: SummaryWriter = SummaryWriter(run_dir)

    def after_iteration(
        self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
    ) -> bool:
        if not evals_log:
            return False

        for dataset_name, metrics in evals_log.items():
            # The "train" split is logged separately from every other split.
            tag = "train_loss" if dataset_name == "train" else "valid_loss"
            for _metric_name, history in metrics.items():
                latest = history[-1]
                if isinstance(latest, tuple):
                    latest = latest[0]
                self.writer.add_scalar(tag, latest ** 2, epoch)

        # False => let xgboost keep training.
        return False

    def after_training(self, model):
        # Flush and close the writer once training finishes.
        self.writer.flush()
        self.writer.close()
        return model

View File

@@ -0,0 +1,35 @@
import logging
from pathlib import Path
from typing import Any
import xgboost as xgb
logger = logging.getLogger(__name__)
class BaseTensorboardLogger:
    """No-op fallback logger used when torch/tensorboard is not installed.

    Warns once at construction and silently discards all scalar logs so
    callers can use the same interface regardless of installation.
    """

    def __init__(self, logdir: Path):
        # Space added between the two concatenated sentences; original emitted
        # "written.Ensure ..." with no separator.
        logger.warning("Tensorboard is not installed, no logs will be written. "
                       "Ensure torch is installed, or use the torch/RL docker images")

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Discard the scalar; the torch-backed subclass writes it instead.

        Renamed from the typo'd ``log_scaler`` so the method actually matches
        the torch implementation's ``log_scalar`` override.
        """
        return

    def log_scaler(self, tag: str, scalar_value: Any, step: int):
        """Deprecated alias for :meth:`log_scalar` (original typo'd name)."""
        return self.log_scalar(tag, scalar_value, step)

    def close(self):
        """Nothing to release in the no-op implementation."""
        return
class BaseTensorBoardCallback(xgb.callback.TrainingCallback):
    """No-op xgboost training callback used when tensorboard is unavailable.

    Warns at construction; every callback hook is a pass-through so training
    proceeds normally without writing any logs.
    """

    def __init__(self, logdir: Path):
        # Space added between the two concatenated sentences; original emitted
        # "written.Ensure ..." with no separator.
        logger.warning("Tensorboard is not installed, no logs will be written. "
                       "Ensure torch is installed, or use the torch/RL docker images")

    def after_iteration(
        self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
    ) -> bool:
        # Returning False tells xgboost to continue training.
        return False

    def after_training(self, model):
        # Must return the model unchanged per the TrainingCallback contract.
        return model

View File

@@ -0,0 +1,52 @@
import logging
from pathlib import Path
from typing import Any
from torch.utils.tensorboard import SummaryWriter
from xgboost import callback
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
BaseTensorboardLogger)
logger = logging.getLogger(__name__)
class TensorboardLogger(BaseTensorboardLogger):
    """Torch-backed logger that writes scalars beneath ``<logdir>/tensorboard``."""

    def __init__(self, logdir: Path):
        # Keep every run's events in a dedicated "tensorboard" subfolder.
        run_dir = f"{str(logdir)}/tensorboard"
        self.writer: SummaryWriter = SummaryWriter(run_dir)

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Record one scalar value for *tag* at the given *step*."""
        self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Flush any buffered events, then release the writer."""
        self.writer.flush()
        self.writer.close()
class TensorBoardCallback(BaseTensorBoardCallback):
    """Torch-backed xgboost callback that mirrors train/valid loss to tensorboard."""

    def __init__(self, logdir: Path):
        # Keep every run's events in a dedicated "tensorboard" subfolder.
        run_dir = f"{str(logdir)}/tensorboard"
        self.writer: SummaryWriter = SummaryWriter(run_dir)

    def after_iteration(
        self, model, epoch: int, evals_log: callback.TrainingCallback.EvalsLog
    ) -> bool:
        if not evals_log:
            return False

        for dataset_name, metrics in evals_log.items():
            # The "train" split is logged separately from every other split.
            tag = "train_loss" if dataset_name == "train" else "valid_loss"
            for _metric_name, history in metrics.items():
                latest = history[-1]
                if isinstance(latest, tuple):
                    latest = latest[0]
                self.writer.add_scalar(tag, latest ** 2, epoch)

        # False => let xgboost keep training.
        return False

    def after_training(self, model):
        # Flush and close the writer once training finishes.
        self.writer.flush()
        self.writer.close()
        return model

View File

@@ -188,7 +188,7 @@ def test_get_full_model_path(mocker, freqai_conf, model):
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
new_timerange = TimeRange.parse_timerange("20180120-20180130")
freqai.dk.set_paths('ADA/BTC', None)
freqai.extract_data_and_train_model(
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)

View File

@@ -282,6 +282,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
df[f'%-constant_{i}'] = i
metadata = {"pair": "LTC/BTC"}
freqai.dk.set_paths('LTC/BTC', None)
freqai.start_backtesting(df, metadata, freqai.dk, strategy)
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
@@ -439,6 +440,7 @@ def test_principal_component_analysis(mocker, freqai_conf):
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
new_timerange = TimeRange.parse_timerange("20180120-20180130")
freqai.dk.set_paths('ADA/BTC', None)
freqai.extract_data_and_train_model(
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
@@ -472,6 +474,7 @@ def test_plot_feature_importance(mocker, freqai_conf):
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
new_timerange = TimeRange.parse_timerange("20180120-20180130")
freqai.dk.set_paths('ADA/BTC', None)
freqai.extract_data_and_train_model(
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)