mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
fix the import logic, fix tests, put all tensorboard in a single folder
This commit is contained in:
@@ -23,7 +23,7 @@ from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv
|
||||
from freqtrade.freqai.RL.BaseEnvironment import BaseActions, BaseEnvironment, Positions
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.persistence import Trade
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# ensure users can still use a non-torch freqai version
|
||||
try:
|
||||
from freqtrade.freqai.tensorboard import TensorBoardCallback, TensorboardLogger
|
||||
from freqtrade.freqai.tensorboard.tensorboard import TensorBoardCallback, TensorboardLogger
|
||||
TBLogger = TensorboardLogger
|
||||
TBCallback = TensorBoardCallback
|
||||
except ModuleNotFoundError:
|
||||
from freqtrade.freqai.tensorboard import BaseTensorBoardCallback, BaseTensorboardLogger
|
||||
TBLogger = BaseTensorboardLogger # type: ignore
|
||||
TBCallback = BaseTensorBoardCallback # type: ignore
|
||||
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
|
||||
BaseTensorboardLogger)
|
||||
TBLogger = BaseTensorboardLogger
|
||||
TBCallback = BaseTensorBoardCallback
|
||||
|
||||
__all__ = (
|
||||
"TBLogger",
|
||||
|
||||
@@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.prediction_models.ReinforcementLearner import ReinforcementLearner
|
||||
from freqtrade.freqai.RL.BaseReinforcementLearningModel import make_env
|
||||
from freqtrade.freqai.RL.TensorboardCallback import TensorboardCallback
|
||||
from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTensorboardLogger:
|
||||
def __init__(self, logdir: str = "tensorboard", id: str = "unique-id"):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Use ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def log_scaler(self, tag: str, scalar_value: Any, step: int):
|
||||
return
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
|
||||
class BaseTensorBoardCallback(xgb.callback.TrainingCallback):
|
||||
|
||||
def __init__(self, logdir: str = "tensorboard", id: str = "uniqu-id", test_size=1):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Use ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
return model
|
||||
|
||||
|
||||
class TensorboardLogger(BaseTensorboardLogger):
|
||||
def __init__(self, logdir: Path = Path("tensorboard")):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def log_scalar(self, tag: str, scalar_value: Any, step: int):
|
||||
self.writer.add_scalar(tag, scalar_value, step)
|
||||
|
||||
def close(self):
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
|
||||
class TensorBoardCallback(BaseTensorBoardCallback):
|
||||
|
||||
def __init__(self, logdir: Path = Path("tensorboard")):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
|
||||
if data == "train":
|
||||
self.writer.add_scalar("train_loss", score**2, epoch)
|
||||
else:
|
||||
self.writer.add_scalar("valid_loss", score**2, epoch)
|
||||
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
return model
|
||||
35
freqtrade/freqai/tensorboard/base_tensorboard.py
Normal file
35
freqtrade/freqai/tensorboard/base_tensorboard.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTensorboardLogger:
|
||||
def __init__(self, logdir: Path):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def log_scaler(self, tag: str, scalar_value: Any, step: int):
|
||||
return
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
|
||||
class BaseTensorBoardCallback(xgb.callback.TrainingCallback):
|
||||
|
||||
def __init__(self, logdir: Path):
|
||||
logger.warning("Tensorboard is not installed, no logs will be written."
|
||||
"Ensure torch is installed, or use the torch/RL docker images")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
return model
|
||||
52
freqtrade/freqai/tensorboard/tensorboard.py
Normal file
52
freqtrade/freqai/tensorboard/tensorboard.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from xgboost import callback
|
||||
|
||||
from freqtrade.freqai.tensorboard.base_tensorboard import (BaseTensorBoardCallback,
|
||||
BaseTensorboardLogger)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TensorboardLogger(BaseTensorboardLogger):
|
||||
def __init__(self, logdir: Path):
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def log_scalar(self, tag: str, scalar_value: Any, step: int):
|
||||
self.writer.add_scalar(tag, scalar_value, step)
|
||||
|
||||
def close(self):
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
|
||||
class TensorBoardCallback(BaseTensorBoardCallback):
|
||||
|
||||
def __init__(self, logdir: Path):
|
||||
self.writer: SummaryWriter = SummaryWriter(f"{str(logdir)}/tensorboard")
|
||||
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: callback.TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
score = log[-1][0] if isinstance(log[-1], tuple) else log[-1]
|
||||
if data == "train":
|
||||
self.writer.add_scalar("train_loss", score**2, epoch)
|
||||
else:
|
||||
self.writer.add_scalar("valid_loss", score**2, epoch)
|
||||
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
self.writer.flush()
|
||||
self.writer.close()
|
||||
|
||||
return model
|
||||
@@ -188,7 +188,7 @@ def test_get_full_model_path(mocker, freqai_conf, model):
|
||||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
|
||||
|
||||
@@ -282,6 +282,7 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
|
||||
df[f'%-constant_{i}'] = i
|
||||
|
||||
metadata = {"pair": "LTC/BTC"}
|
||||
freqai.dk.set_paths('LTC/BTC', None)
|
||||
freqai.start_backtesting(df, metadata, freqai.dk, strategy)
|
||||
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
|
||||
|
||||
@@ -439,6 +440,7 @@ def test_principal_component_analysis(mocker, freqai_conf):
|
||||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
@@ -472,6 +474,7 @@ def test_plot_feature_importance(mocker, freqai_conf):
|
||||
|
||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||
freqai.dk.set_paths('ADA/BTC', None)
|
||||
|
||||
freqai.extract_data_and_train_model(
|
||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
||||
|
||||
Reference in New Issue
Block a user