diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index ea7981405..b29d20112 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -74,16 +74,17 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): model.to(self.device) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) criterion = torch.nn.CrossEntropyLoss() - init_model = self.get_init_model(dk.pair) - trainer = PyTorchModelTrainer( - model=model, - optimizer=optimizer, - criterion=criterion, - model_meta_data={"class_names": class_names}, - device=self.device, - init_model=init_model, - data_convertor=self.data_convertor, - **self.trainer_kwargs, - ) + # check if continual_learning is activated, and retrieve the model to continue training + trainer = self.get_init_model(dk.pair) + if trainer is None: + trainer = PyTorchModelTrainer( + model=model, + optimizer=optimizer, + criterion=criterion, + model_meta_data={"class_names": class_names}, + device=self.device, + data_convertor=self.data_convertor, + **self.trainer_kwargs, + ) trainer.fit(data_dictionary, self.splits) return trainer diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py index 64f0f4b03..6e1270102 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py @@ -69,15 +69,16 @@ class PyTorchMLPRegressor(BasePyTorchRegressor): model.to(self.device) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) criterion = torch.nn.MSELoss() - init_model = self.get_init_model(dk.pair) - trainer = PyTorchModelTrainer( - model=model, - optimizer=optimizer, - criterion=criterion, - device=self.device, - init_model=init_model, - data_convertor=self.data_convertor, - **self.trainer_kwargs, - ) + # check if 
continual_learning is activated, and retrieve the model to continue training + trainer = self.get_init_model(dk.pair) + if trainer is None: + trainer = PyTorchModelTrainer( + model=model, + optimizer=optimizer, + criterion=criterion, + device=self.device, + data_convertor=self.data_convertor, + **self.trainer_kwargs, + ) trainer.fit(data_dictionary, self.splits) return trainer diff --git a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py index e760f6e68..5e84ada72 100644 --- a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py @@ -75,17 +75,18 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor): model.to(self.device) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) criterion = torch.nn.MSELoss() - init_model = self.get_init_model(dk.pair) - trainer = PyTorchTransformerTrainer( - model=model, - optimizer=optimizer, - criterion=criterion, - device=self.device, - init_model=init_model, - data_convertor=self.data_convertor, - window_size=self.window_size, - **self.trainer_kwargs, - ) + # check if continual_learning is activated, and retrieve the model to continue training + trainer = self.get_init_model(dk.pair) + if trainer is None: + trainer = PyTorchTransformerTrainer( + model=model, + optimizer=optimizer, + criterion=criterion, + device=self.device, + data_convertor=self.data_convertor, + window_size=self.window_size, + **self.trainer_kwargs, + ) trainer.fit(data_dictionary, self.splits) return trainer diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index a3b0d9b9c..a25fa45bc 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): optimizer: Optimizer, criterion: nn.Module, device: 
str, - init_model: Dict, + # init_model: Dict, data_convertor: PyTorchDataConvertor, model_meta_data: Dict[str, Any] = {}, window_size: int = 1, @@ -56,8 +56,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface): self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) self.data_convertor = data_convertor self.window_size: int = window_size - if init_model: - self.load_from_checkpoint(init_model) + # if init_model: + # self.load_from_checkpoint(init_model) def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]): """