mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-18 22:01:15 +00:00
Merge pull request #8580 from freqtrade/feat/add-transformer
Add transformer to FreqAI
This commit is contained in:
@@ -75,6 +75,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
||||
dk.data_dictionary["prediction_features"],
|
||||
device=self.device
|
||||
)
|
||||
self.model.model.eval()
|
||||
logits = self.model.model(x)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
predicted_classes = torch.argmax(probs, dim=-1)
|
||||
|
||||
@@ -27,6 +27,7 @@ class BasePyTorchModel(IFreqaiModel, ABC):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
test_size = self.freqai_info.get('data_split_parameters', {}).get('test_size')
|
||||
self.splits = ["train", "test"] if test_size != 0 else ["train"]
|
||||
self.window_size = self.freqai_info.get("conv_width", 1)
|
||||
|
||||
def train(
|
||||
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||
|
||||
@@ -44,6 +44,7 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
||||
dk.data_dictionary["prediction_features"],
|
||||
device=self.device
|
||||
)
|
||||
self.model.model.eval()
|
||||
y = self.model.model(x)
|
||||
pred_df = DataFrame(y.detach().tolist(), columns=[dk.label_list[0]])
|
||||
return (pred_df, dk.do_predict)
|
||||
|
||||
Reference in New Issue
Block a user