Merge pull request #8580 from freqtrade/feat/add-transformer

Add transformer to FreqAI
This commit is contained in:
Robert Caulk
2023-05-07 11:32:38 +02:00
committed by GitHub
14 changed files with 337 additions and 26 deletions

View File

@@ -50,7 +50,8 @@ def can_run_model(model: str) -> None:
('XGBoostRegressor', False, True, False, True, False, 10),
('XGBoostRFRegressor', False, False, False, True, False, 0),
('CatboostRegressor', False, False, False, True, True, 0),
('PyTorchMLPRegressor', False, False, False, True, False, 0),
('PyTorchMLPRegressor', False, False, False, False, False, 0),
('PyTorchTransformerRegressor', False, False, False, False, False, 0),
('ReinforcementLearner', False, True, False, True, False, 0),
('ReinforcementLearner_multiproc', False, False, False, True, False, 0),
('ReinforcementLearner_test_3ac', False, False, False, False, False, 0),
@@ -82,10 +83,13 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
freqai_conf["freqai"]["rl_config"]["drop_ohlc_from_features"] = True
if 'PyTorchMLPRegressor' in model:
if 'PyTorch' in model:
model_save_ext = 'zip'
pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters()
freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp)
if 'Transformer' in model:
# transformer model takes a window, unlike the MLP regressor
freqai_conf.update({"conv_width": 10})
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
@@ -228,6 +232,7 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
("XGBoostRegressor", 2, "freqai_test_strat"),
("CatboostRegressor", 2, "freqai_test_strat"),
("PyTorchMLPRegressor", 2, "freqai_test_strat"),
("PyTorchTransformerRegressor", 2, "freqai_test_strat"),
("ReinforcementLearner", 3, "freqai_rl_test_strat"),
("XGBoostClassifier", 2, "freqai_test_classifier"),
("LightGBMClassifier", 2, "freqai_test_classifier"),
@@ -253,9 +258,12 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog)
if 'test_4ac' in model:
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
if 'PyTorchMLP' in model:
if 'PyTorch' in model:
pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters()
freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp)
if 'Transformer' in model:
# transformer model takes a window, unlike the MLP regressor
freqai_conf.update({"conv_width": 10})
freqai_conf.get("freqai", {}).get("feature_parameters", {}).update(
{"indicator_periods_candles": [2]})