diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index 2855dda33..3abc56fb1 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -34,13 +34,15 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): """ super().__init__(**kwargs) - trainer_kwargs = self.freqai_info.get("trainer_kwargs", {}) - self.n_hidden: int = trainer_kwargs.get("n_hidden", 1024) - self.max_iters: int = trainer_kwargs.get("max_iters", 100) - self.batch_size: int = trainer_kwargs.get("batch_size", 64) - self.learning_rate: float = trainer_kwargs.get("learning_rate", 3e-4) - self.max_n_eval_batches: Optional[int] = trainer_kwargs.get("max_n_eval_batches", None) - self.model_kwargs: Dict = trainer_kwargs.get("model_kwargs", {}) + model_training_params = self.freqai_info.get("model_training_parameters", {}) + self.n_hidden: int = model_training_params.get("n_hidden", 1024) + self.max_iters: int = model_training_params.get("max_iters", 100) + self.batch_size: int = model_training_params.get("batch_size", 64) + self.learning_rate: float = model_training_params.get("learning_rate", 3e-4) + self.max_n_eval_batches: Optional[int] = model_training_params.get( + "max_n_eval_batches", None + ) + self.model_kwargs: Dict = model_training_params.get("model_kwargs", {}) self.class_name_to_index = None self.index_to_class_name = None