diff --git a/freqtrade/freqai/prediction_models/SKLearnRandomForestClassifier.py b/freqtrade/freqai/prediction_models/SKLearnRandomForestClassifier.py new file mode 100644 index 000000000..4462efc49 --- /dev/null +++ b/freqtrade/freqai/prediction_models/SKLearnRandomForestClassifier.py @@ -0,0 +1,82 @@ +import logging +from typing import Any, Dict, Tuple + +import numpy as np +import numpy.typing as npt +from pandas import DataFrame +from sklearn.ensemble import RandomForestClassifier +from sklearn.preprocessing import LabelEncoder + +from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen + + +logger = logging.getLogger(__name__) + + +class SKLearnRandomForestClassifier(BaseClassifierModel): + """ + User created prediction model. The class inherits IFreqaiModel, which + means it has full access to all Frequency AI functionality. Typically, + users would use this to override the common `fit()`, `train()`, or + `predict()` methods to add their custom data handling tools or change + various aspects of the training that cannot be configured via the + top level config.json file. + """ + + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: + """ + User sets up the training and test data to fit their desired model here + :param data_dictionary: the dictionary holding all data for train, test, + labels, weights + :param dk: The datakitchen object for the current coin/model + """ + + X = data_dictionary["train_features"].to_numpy() + y = data_dictionary["train_labels"].to_numpy()[:, 0] + + if self.freqai_info.get('data_split_parameters', {}).get('test_size', 0.1) == 0: + eval_set = None + else: + test_features = data_dictionary["test_features"].to_numpy() + test_labels = data_dictionary["test_labels"].to_numpy()[:, 0] + + eval_set = (test_features, test_labels) + + if self.freqai_info.get("continual_learning", False): + logger.warning("Continual learning is not supported for " + "SKLearnRandomForestClassifier, ignoring.") + + train_weights = data_dictionary["train_weights"] + + model = RandomForestClassifier(**self.model_training_parameters) + + model.fit(X=X, y=y, sample_weight=train_weights) + if eval_set: + logger.info("Score: %s", model.score(eval_set[0], eval_set[1])) + + return model + + def predict( + self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs + ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: + """ + Filter the prediction features data and predict with it. + :param unfiltered_df: Full dataframe for the current backtest period. + :return: + :pred_df: dataframe containing the predictions + :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove + data (NaNs) or felt uncertain about data (PCA and DI index) + """ + + (pred_df, dk.do_predict) = super().predict(unfiltered_df, dk, **kwargs) + + le = LabelEncoder() + label = dk.label_list[0] + labels_before = list(dk.data['labels_std'].keys()) + labels_after = le.fit_transform(labels_before).tolist() + pred_df[label] = le.inverse_transform(pred_df[label]) + pred_df = pred_df.rename( + columns={labels_after[i]: labels_before[i] for i in range(len(labels_before))}) + + return (pred_df, dk.do_predict) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 7638c03ed..9fc722ab1 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -176,6 +176,7 @@ def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model, s 'CatboostClassifier', 'XGBoostClassifier', 'XGBoostRFClassifier', + 'SKLearnRandomForestClassifier', 'PyTorchMLPClassifier', ]) def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index e94509b40..17b0399d9 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -1770,6 +1770,7 @@ def test_api_freqaimodels(botclient, tmp_path, mocker): {'name': 'LightGBMRegressorMultiTarget'}, {'name': 'ReinforcementLearner'}, {'name': 'ReinforcementLearner_multiproc'}, + {'name': 'SKlearnRandomForestClassifier'}, {'name': 'XGBoostClassifier'}, {'name': 'XGBoostRFClassifier'}, {'name': 'XGBoostRFRegressor'}, @@ -1788,6 +1789,7 @@ def test_api_freqaimodels(botclient, tmp_path, mocker): 'LightGBMRegressorMultiTarget', 'ReinforcementLearner', 'ReinforcementLearner_multiproc', + 'SKlearnRandomForestClassifier', 'XGBoostClassifier', 'XGBoostRFClassifier', 'XGBoostRFRegressor',