diff --git a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py index aa5dbe629..a9db81e31 100644 --- a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py +++ b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py @@ -36,9 +36,6 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): y = self._validate_data(X="no_validation", y=y, multi_output=True) - # if is_classifier(self): - # check_classification_targets(y) - if y.ndim == 1: raise ValueError( "y must have at least two dimensions for " @@ -50,19 +47,12 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): ): raise ValueError("Underlying estimator does not support sample weights.") - # fit_params_validated = _check_fit_params(X, fit_params) - if not fit_params: fit_params = [None] * y.shape[1] - # if not init_models: - # init_models = [None] * y.shape[1] - self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_estimator)( self.estimator, X, y[:, i], sample_weight, **fit_params[i] - # init_model=init_models[i], eval_set=eval_sets[i], - # **fit_params_validated ) for i in range(y.shape[1]) ) diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index a376b2c33..7fa4e293e 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): {'eval_set': eval_sets[i], 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=cbr) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 7a9b5c36a..37c6bb186 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=lgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) - # model = FreqaiMultiOutputRegressor(estimator=lgb) - # model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models, - # eval_sets=eval_sets, eval_sample_weight=eval_weights) return model diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py index 38c478c0b..920745ec9 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py @@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel): 'xgb_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=xgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 5f8eeb086..2a7cfeb73 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -17,166 +17,17 @@ def is_arm() -> bool: return "arm" in machine or "aarch64" in machine -def test_extract_data_and_train_model_LightGBM(mocker, freqai_conf): +@pytest.mark.parametrize('model', [ + 'LightGBMRegressor', + 'XGBoostRegressor', + 'CatboostRegressor', + ]) +def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostRegressor': + pytest.skip("CatBoost is not supported on ARM") + + freqai_conf.update({"freqaimodel": model}) freqai_conf.update({"timerange": "20180110-20180130"}) - - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model( - new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() - - shutil.rmtree(Path(freqai.dk.full_path)) - - -def test_extract_data_and_train_model_LightGBMMultiModel(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) - freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model( - new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) - - assert len(freqai.dk.label_list) == 2 - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() - assert len(freqai.dk.data['training_features_list']) == 26 - - shutil.rmtree(Path(freqai.dk.full_path)) - - -@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...") -def test_extract_data_and_train_model_Catboost(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "CatboostRegressor"}) - # freqai_conf.get('freqai', {}).update( - # {'model_training_parameters': {"n_estimators": 100, "verbose": 0}}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", - strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() - - shutil.rmtree(Path(freqai.dk.full_path)) - - -@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...") -def test_extract_data_and_train_model_CatboostClassifier(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "CatboostClassifier"}) - freqai_conf.update({"strategy": "freqai_test_classifier"}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", - strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() - - shutil.rmtree(Path(freqai.dk.full_path)) - - -def test_extract_data_and_train_model_LightGBMClassifier(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "LightGBMClassifier"}) - freqai_conf.update({"strategy": "freqai_test_classifier"}) - strategy = get_patched_freqai_strategy(mocker, freqai_conf) - exchange = get_patched_exchange(mocker, freqai_conf) - strategy.dp = DataProvider(freqai_conf, exchange) - - strategy.freqai_info = freqai_conf.get("freqai", {}) - freqai = strategy.freqai - freqai.live = True - freqai.dk = FreqaiDataKitchen(freqai_conf) - timerange = TimeRange.parse_timerange("20180110-20180130") - freqai.dd.load_all_pair_histories(timerange, freqai.dk) - - freqai.dd.pair_dict = MagicMock() - - data_load_timerange = TimeRange.parse_timerange("20180110-20180130") - new_timerange = TimeRange.parse_timerange("20180120-20180130") - - freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", - strategy, freqai.dk, data_load_timerange) - - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() - assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() - - shutil.rmtree(Path(freqai.dk.full_path)) - - -def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf): - freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "XGBoostRegressor"}) freqai_conf.update({"strategy": "freqai_test_strat"}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) @@ -205,10 +56,18 @@ def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf): shutil.rmtree(Path(freqai.dk.full_path)) -def test_extract_data_and_train_model_XGBoostRegressorMultiModel(mocker, freqai_conf): +@pytest.mark.parametrize('model', [ + 'LightGBMRegressorMultiTarget', + 'XGBoostRegressorMultiTarget', + 'CatboostRegressorMultiTarget', + ]) +def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostRegressorMultiTarget': + pytest.skip("CatBoost is not supported on ARM") + freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "XGBoostRegressorMultiTarget"}) freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) + freqai_conf.update({"freqaimodel": model}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange) @@ -237,6 +96,44 @@ def test_extract_data_and_train_model_XGBoostRegressorMultiModel(mocker, freqai_ shutil.rmtree(Path(freqai.dk.full_path)) +@pytest.mark.parametrize('model', [ + 'LightGBMClassifier', + 'CatboostClassifier', + ]) +def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): + if is_arm() and model == 'CatboostClassifier': + pytest.skip("CatBoost is not supported on ARM") + + freqai_conf.update({"freqaimodel": model}) + freqai_conf.update({"strategy": "freqai_test_classifier"}) + freqai_conf.update({"timerange": "20180110-20180130"}) + strategy = get_patched_freqai_strategy(mocker, freqai_conf) + exchange = get_patched_exchange(mocker, freqai_conf) + strategy.dp = DataProvider(freqai_conf, exchange) + + strategy.freqai_info = freqai_conf.get("freqai", {}) + freqai = strategy.freqai + freqai.live = True + freqai.dk = FreqaiDataKitchen(freqai_conf) + timerange = TimeRange.parse_timerange("20180110-20180130") + freqai.dd.load_all_pair_histories(timerange, freqai.dk) + + freqai.dd.pair_dict = MagicMock() + + data_load_timerange = TimeRange.parse_timerange("20180110-20180130") + new_timerange = TimeRange.parse_timerange("20180120-20180130") + + freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", + strategy, freqai.dk, data_load_timerange) + + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists() + + shutil.rmtree(Path(freqai.dk.full_path)) + + def test_start_backtesting(mocker, freqai_conf): freqai_conf.update({"timerange": "20180120-20180130"}) freqai_conf.get("freqai", {}).update({"save_backtest_models": True})