From 68728409aa5c7b668fbb1228cfed117901a7f1a9 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 20 Mar 2023 18:04:14 +0200 Subject: [PATCH] add pytorch regressor test --- tests/freqai/test_freqai_interface.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 01aa0d1db..3407a5a95 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -48,11 +48,12 @@ def can_run_model(model: str) -> None: ('XGBoostRegressor', False, True, False, True, False, 10), ('XGBoostRFRegressor', False, False, False, True, False, 0), ('CatboostRegressor', False, False, False, True, True, 0), + ('MLPPyTorchRegressor', False, False, False, True, False, 0), ('ReinforcementLearner', False, True, False, True, False, 0), ('ReinforcementLearner_multiproc', False, False, False, True, False, 0), ('ReinforcementLearner_test_3ac', False, False, False, False, False, 0), ('ReinforcementLearner_test_3ac', False, False, False, True, False, 0), - ('ReinforcementLearner_test_4ac', False, False, False, True, False, 0) + ('ReinforcementLearner_test_4ac', False, False, False, True, False, 0), ]) def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32, can_short, shuffle, buffer): @@ -85,6 +86,9 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'test_3ac' in model or 'test_4ac' in model: freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") + if 'MLPPyTorchRegressor' in model: + model_save_ext = 'zip' + strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange)