diff --git a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py index e2c0f5fda..b024f58af 100644 --- a/freqtrade/freqai/RL/BaseReinforcementLearningModel.py +++ b/freqtrade/freqai/RL/BaseReinforcementLearningModel.py @@ -23,7 +23,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions -from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback +from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback from freqtrade.persistence import Trade diff --git a/freqtrade/freqai/__init__.py b/freqtrade/freqai/__init__.py index 5fb6e5be0..14353de98 100644 --- a/freqtrade/freqai/__init__.py +++ b/freqtrade/freqai/__init__.py @@ -1,12 +1,13 @@ # ensure users can still use a non-torch freqai version try: - from freqtrade.freqai.tensorboard import TensorBoardCallback, TensorboardLogger + from freqtrade.freqai.tensorboard.tensorboard import TensorBoardCallback, TensorboardLogger TBLogger = TensorboardLogger TBCallback = TensorBoardCallback except ModuleNotFoundError: - from freqtrade.freqai.tensorboard import BaseTensorBoardCallback, BaseTensorboardLogger - TBLogger = BaseTensorboardLogger # type: ignore - TBCallback = BaseTensorBoardCallback # type: ignore + from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback, + BaseTensorboardLogger) + TBLogger = BaseTensorboardLogger + TBCallback = BaseTensorBoardCallback __all__ = ( "TBLogger", diff --git a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py index 73f617027..9f0b2d436 100644 --- a/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py +++ b/freqtrade/freqai/prediction_models/ReinforcementLearner_multiproc.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env -from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback +from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback logger = logging.getLogger(__name__) diff --git a/freqtrade/freqai/tensorboard.py b/freqtrade/freqai/tensorboard.py deleted file mode 100644 index cb536008e..000000000 --- a/freqtrade/freqai/tensorboard.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -from pathlib import Path -from typing import Any - -import xgboost as xgb - - -logger = logging.getLogger(__name__) - - -class BaseTensorboardLogger: - def __init__(self, logdir: str = "tensorboard", id: str = "unique-id"): - logger.warning("Tensorboard is not installed, no logs will be written." - "Use ensure torch is installed, or use the torch/RL docker images") - - def log_scaler(self, tag: str, scalar_value: Any, step: int): - return - - def close(self): - return - - -class BaseTensorBoardCallback(xgb.callback.TrainingCallback): - - def __init__(self, logdir: str = "tensorboard", id: str = "uniqu-id", test_size=1): - logger.warning("Tensorboard is not installed, no logs will be written." - "Use ensure torch is installed, or use the torch/RL docker images") - - def after_iteration( - self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog - ) -> bool: - return False - - def after_training(self, model): - return model - - -class TensorboardLogger(BaseTensorboardLogger): - def __init__(self, logdir: Path = Path("tensorboard")): - from torch.utils.tensorboard import SummaryWriter - self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard") - - def log_scalar(self, tag: str, scalar_value: Any, step: int): - self.writer.add_scalar(tag, scalar_value, step) - - def close(self): - self.writer.flush() - self.writer.close() - - -class TensorBoardCallback(BaseTensorBoardCallback): - - def __init__(self, logdir: Path = Path("tensorboard")): - from torch.utils.tensorboard import SummaryWriter - self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard") - - def after_iteration( - self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog - ) -> bool: - if not evals_log: - return False - - for data, metric in evals_log.items(): - for metric_name, log in metric.items(): - score = log[-1][0] if isinstance(log[-1], tuple) else log[-1] - if data == "train": - self.writer.add_scalar("train_loss", score**2, epoch) - else: - self.writer.add_scalar("valid_loss", score**2, epoch) - - return False - - def after_training(self, model): - self.writer.flush() - self.writer.close() - - return model diff --git a/freqtrade/freqai/RL/TensorboardCallback.py b/freqtrade/freqai/tensorboard/TensorboardCallback.py similarity index 100% rename from freqtrade/freqai/RL/TensorboardCallback.py rename to freqtrade/freqai/tensorboard/TensorboardCallback.py diff --git a/freqtrade/freqai/tensorboard/base_tensorboard.py b/freqtrade/freqai/tensorboard/base_tensorboard.py new file mode 100644 index 000000000..186658532 --- /dev/null +++ b/freqtrade/freqai/tensorboard/base_tensorboard.py @@ -0,0 +1,35 @@ +import logging +from pathlib import Path +from typing import Any + +import xgboost as xgb + + +logger = logging.getLogger(__name__) + + +class BaseTensorboardLogger: + def __init__(self, logdir: Path): + logger.warning("Tensorboard is not installed, no logs will be written." + "Ensure torch is installed, or use the torch/RL docker images") + + def log_scaler(self, tag: str, scalar_value: Any, step: int): + return + + def close(self): + return + + +class BaseTensorBoardCallback(xgb.callback.TrainingCallback): + + def __init__(self, logdir: Path): + logger.warning("Tensorboard is not installed, no logs will be written." + "Ensure torch is installed, or use the torch/RL docker images") + + def after_iteration( + self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog + ) -> bool: + return False + + def after_training(self, model): + return model diff --git a/freqtrade/freqai/tensorboard/tensorboard.py b/freqtrade/freqai/tensorboard/tensorboard.py new file mode 100644 index 000000000..f9070be6e --- /dev/null +++ b/freqtrade/freqai/tensorboard/tensorboard.py @@ -0,0 +1,52 @@ +import logging +from pathlib import Path +from typing import Any + +from torch.utils.tensorboard import SummaryWriter +from xgboost import callback + +from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback, + BaseTensorboardLogger) + + +logger = logging.getLogger(__name__) + + +class TensorboardLogger(BaseTensorboardLogger): + def __init__(self, logdir: Path): + self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard") + + def log_scalar(self, tag: str, scalar_value: Any, step: int): + self.writer.add_scalar(tag, scalar_value, step) + + def close(self): + self.writer.flush() + self.writer.close() + + +class TensorBoardCallback(BaseTensorBoardCallback): + + def __init__(self, logdir: Path): + self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard") + + def after_iteration( + self, model, epoch: int, evals_log: callback.TrainingCallback.EvalsLog + ) -> bool: + if not evals_log: + return False + + for data, metric in evals_log.items(): + for metric_name, log in metric.items(): + score = log[-1][0] if isinstance(log[-1], tuple) else log[-1] + if data == "train": + self.writer.add_scalar("train_loss", score**2, epoch) + else: + self.writer.add_scalar("valid_loss", score**2, epoch) + + return False + + def after_training(self, model): + self.writer.flush() + self.writer.close() + + return model diff --git a/tests/freqai/test_freqai_datakitchen.py b/tests/freqai/test_freqai_datakitchen.py index c9d3a973c..cbc4acd18 100644 --- a/tests/freqai/test_freqai_datakitchen.py +++ b/tests/freqai/test_freqai_datakitchen.py @@ -188,7 +188,7 @@ def test_get_full_model_path(mocker, freqai_conf, model): data_load_timerange = TimeRange.parse_timerange("20180110-20180130") new_timerange = TimeRange.parse_timerange("20180120-20180130") - + freqai.dk.set_paths('ADA/BTC', None) freqai.extract_data_and_train_model( new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 5291185f0..a2e4f182a 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -282,6 +282,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog) df[f'%-constant_{i}'] = i metadata = {"pair": "LTC/BTC"} + freqai.dk.set_paths('LTC/BTC', None) freqai.start_backtesting(df, metadata, freqai.dk, strategy) model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()] @@ -439,6 +440,7 @@ def test_principal_component_analysis(mocker, freqai_conf): data_load_timerange = TimeRange.parse_timerange("20180110-20180130") new_timerange = TimeRange.parse_timerange("20180120-20180130") + freqai.dk.set_paths('ADA/BTC', None) freqai.extract_data_and_train_model( new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) @@ -472,6 +474,7 @@ def test_plot_feature_importance(mocker, freqai_conf): data_load_timerange = TimeRange.parse_timerange("20180110-20180130") new_timerange = TimeRange.parse_timerange("20180120-20180130") + freqai.dk.set_paths('ADA/BTC', None) freqai.extract_data_and_train_model( new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)