ensure data kitchen thread count is propagated to pipeline

This commit is contained in:
robcaulk
2023-06-08 12:33:08 +02:00
parent 88337b6c5e
commit 33b028b104
7 changed files with 12 additions and 32 deletions

View File

@@ -53,7 +53,7 @@ class BaseClassifierModel(IFreqaiModel):
dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
dk.fit_labels()
dk.feature_pipeline = self.define_data_pipeline()
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
(dd["train_features"],
dd["train_labels"],

View File

@@ -189,7 +189,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
dk.fit_labels()
dk.feature_pipeline = self.define_data_pipeline()
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
(dd["train_features"],
dd["train_labels"],

View File

@@ -85,8 +85,8 @@ class BasePyTorchRegressor(BasePyTorchModel):
dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
dk.fit_labels()
dk.feature_pipeline = self.define_data_pipeline()
dk.label_pipeline = self.define_label_pipeline()
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)
dd["train_labels"], _, _ = dk.label_pipeline.fit_transform(dd["train_labels"])
dd["test_labels"], _, _ = dk.label_pipeline.transform(dd["test_labels"])

View File

@@ -52,8 +52,8 @@ class BaseRegressionModel(IFreqaiModel):
dd = dk.make_train_test_datasets(features_filtered, labels_filtered)
if not self.freqai_info.get("fit_live_predictions_candles", 0) or not self.live:
dk.fit_labels()
dk.feature_pipeline = self.define_data_pipeline()
dk.label_pipeline = self.define_label_pipeline()
dk.feature_pipeline = self.define_data_pipeline(threads=dk.thread_count)
dk.label_pipeline = self.define_label_pipeline(threads=dk.thread_count)
(dd["train_features"],
dd["train_labels"],