From f97f1dc5c38bd7e3c0b731404040a2a8e56a20e5 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 10 Sep 2022 19:57:21 +0200 Subject: [PATCH 1/4] Test CatboostRegressorMultiTarget, simplify test setup via parametrization --- tests/freqai/test_freqai_interface.py | 44 ++++++--------------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 5f8eeb086..115ad0ea8 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -46,10 +46,18 @@ def test_extract_data_and_train_model_LightGBM(mocker, freqai_conf): shutil.rmtree(Path(freqai.dk.full_path)) -def test_extract_data_and_train_model_LightGBMMultiModel(mocker, freqai_conf): +@pytest.mark.parametrize('model', [ + 'LightGBMRegressorMultiTarget', + 'XGBoostRegressorMultiTarget', + 'CatboostRegressorMultiTarget', + ]) +def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostRegressorMultiTarget': + pytest.skip("CatBoost is not supported on ARM") + freqai_conf.update({"timerange": "20180110-20180130"}) freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) - freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"}) + freqai_conf.update({"freqaimodel": model}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange) @@ -205,38 +213,6 @@ def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf): shutil.rmtree(Path(freqai.dk.full_path)) -def test_extract_data_and_train_model_XGBoostRegressorMultiModel(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "XGBoostRegressorMultiTarget"}) - freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model( - new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) - - assert len(freqai.dk.label_list) == 2 - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() - assert len(freqai.dk.data['training_features_list']) == 26 - - shutil.rmtree(Path(freqai.dk.full_path)) - - def test_start_backtesting(mocker, freqai_conf): freqai_conf.update({"timerange": "20180120-20180130"}) freqai_conf.get("freqai", {}).update({"save_backtest_models": True}) From 88892ba663ea8805e06ecd2d101c8982f595ad5f Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 10 Sep 2022 20:06:52 +0200 Subject: [PATCH 2/4] Parametrize regressor tests --- tests/freqai/test_freqai_interface.py | 76 ++++----------------------- 1 file changed, 11 insertions(+), 65 deletions(-) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 115ad0ea8..bd921c16b 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -17,8 +17,18 @@ def is_arm() -> bool: return "arm" in machine or "aarch64" in machine -def test_extract_data_and_train_model_LightGBM(mocker, freqai_conf): +@pytest.mark.parametrize('model', [ + 'LightGBMRegressor', + 'XGBoostRegressor', + 'CatboostRegressor', + ]) +def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostRegressor': + pytest.skip("CatBoost is not supported on ARM") + + freqai_conf.update({"freqaimodel": model}) freqai_conf.update({"timerange": "20180110-20180130"}) + freqai_conf.update({"strategy": "freqai_test_strat"}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) @@ -86,39 +96,6 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model): shutil.rmtree(Path(freqai.dk.full_path)) -@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...") -def test_extract_data_and_train_model_Catboost(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "CatboostRegressor"}) - # freqai_conf.get('freqai', {}).update( - # {'model_training_parameters': {"n_estimators": 100, "verbose": 0}}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", - strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() - - shutil.rmtree(Path(freqai.dk.full_path)) - - @pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...") def test_extract_data_and_train_model_CatboostClassifier(mocker, freqai_conf): freqai_conf.update({"timerange": "20180110-20180130"}) @@ -182,37 +159,6 @@ def test_extract_data_and_train_model_LightGBMClassifier(mocker, freqai_conf): shutil.rmtree(Path(freqai.dk.full_path)) -def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "XGBoostRegressor"}) - freqai_conf.update({"strategy": "freqai_test_strat"}) - - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model( - new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() - - shutil.rmtree(Path(freqai.dk.full_path)) - - def test_start_backtesting(mocker, freqai_conf): freqai_conf.update({"timerange": "20180120-20180130"}) freqai_conf.get("freqai", {}).update({"save_backtest_models": True}) From b3fc1cfde94b9b8e67e79d1f35f4ce3f36a4e5e0 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 10 Sep 2022 20:17:57 +0200 Subject: [PATCH 3/4] Parametrize classifier tests --- tests/freqai/test_freqai_interface.py | 43 ++++++--------------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index bd921c16b..2a7cfeb73 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -96,42 +96,17 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model): shutil.rmtree(Path(freqai.dk.full_path)) -@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...") -def test_extract_data_and_train_model_CatboostClassifier(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "CatboostClassifier"}) +@pytest.mark.parametrize('model', [ + 'LightGBMClassifier', + 'CatboostClassifier', + ]) +def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostClassifier': + pytest.skip("CatBoost is not supported on ARM") + + freqai_conf.update({"freqaimodel": model}) freqai_conf.update({"strategy": "freqai_test_classifier"}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", - strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() - - shutil.rmtree(Path(freqai.dk.full_path)) - - -def test_extract_data_and_train_model_LightGBMClassifier(mocker, freqai_conf): freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "LightGBMClassifier"}) - freqai_conf.update({"strategy": "freqai_test_classifier"}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange) From 5a0cfee27e8de330d00168b6de86a2ac84e0ed7f Mon Sep 17 00:00:00 2001 From: robcaulk Date: Sat, 10 Sep 2022 22:16:49 +0200 Subject: [PATCH 4/4] allow user to multithread jobs (advanced users only) --- .../freqai/base_models/FreqaiMultiOutputRegressor.py | 10 ---------- .../prediction_models/CatboostRegressorMultiTarget.py | 3 +++ .../prediction_models/LightGBMRegressorMultiTarget.py | 6 +++--- .../prediction_models/XGBoostRegressorMultiTarget.py | 3 +++ 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py index aa5dbe629..a9db81e31 100644 --- a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py +++ b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py @@ -36,9 +36,6 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): y = self._validate_data(X="no_validation", y=y, multi_output=True) - # if is_classifier(self): - # check_classification_targets(y) - if y.ndim == 1: raise ValueError( "y must have at least two dimensions for " @@ -50,19 +47,12 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): ): raise ValueError("Underlying estimator does not support sample weights.") - # fit_params_validated = _check_fit_params(X, fit_params) - if not fit_params: fit_params = [None] * y.shape[1] - # if not init_models: - # init_models = [None] * y.shape[1] - self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_estimator)( self.estimator, X, y[:, i], sample_weight, **fit_params[i] - # init_model=init_models[i], eval_set=eval_sets[i], - # **fit_params_validated ) for i in range(y.shape[1]) ) diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index a376b2c33..7fa4e293e 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): {'eval_set': eval_sets[i], 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=cbr) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 7a9b5c36a..37c6bb186 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=lgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) - # model = FreqaiMultiOutputRegressor(estimator=lgb) - # model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models, - # eval_sets=eval_sets, eval_sample_weight=eval_weights) return model diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py index 38c478c0b..920745ec9 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py @@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel): 'xgb_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=xgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model