mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-01-28 18:00:23 +00:00
fix bug in continual_learning for PyTorch* models
This commit is contained in:
@@ -74,16 +74,17 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
||||
model.to(self.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
init_model = self.get_init_model(dk.pair)
|
||||
trainer = PyTorchModelTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
model_meta_data={"class_names": class_names},
|
||||
device=self.device,
|
||||
init_model=init_model,
|
||||
data_convertor=self.data_convertor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
# check if continual_learning is activated, and retreive the model to continue training
|
||||
trainer = self.get_init_model(dk.pair)
|
||||
if trainer is None:
|
||||
trainer = PyTorchModelTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
model_meta_data={"class_names": class_names},
|
||||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
return trainer
|
||||
|
||||
@@ -69,15 +69,16 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
||||
model.to(self.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
||||
criterion = torch.nn.MSELoss()
|
||||
init_model = self.get_init_model(dk.pair)
|
||||
trainer = PyTorchModelTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
device=self.device,
|
||||
init_model=init_model,
|
||||
data_convertor=self.data_convertor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
# check if continual_learning is activated, and retreive the model to continue training
|
||||
trainer = self.get_init_model(dk.pair)
|
||||
if trainer is None:
|
||||
trainer = PyTorchModelTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
return trainer
|
||||
|
||||
@@ -75,17 +75,18 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
||||
model.to(self.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
||||
criterion = torch.nn.MSELoss()
|
||||
init_model = self.get_init_model(dk.pair)
|
||||
trainer = PyTorchTransformerTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
device=self.device,
|
||||
init_model=init_model,
|
||||
data_convertor=self.data_convertor,
|
||||
window_size=self.window_size,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
# check if continual_learning is activated, and retreive the model to continue training
|
||||
trainer = self.get_init_model(dk.pair)
|
||||
if trainer is None:
|
||||
trainer = PyTorchTransformerTrainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
device=self.device,
|
||||
data_convertor=self.data_convertor,
|
||||
window_size=self.window_size,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
trainer.fit(data_dictionary, self.splits)
|
||||
return trainer
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
optimizer: Optimizer,
|
||||
criterion: nn.Module,
|
||||
device: str,
|
||||
init_model: Dict,
|
||||
# init_model: Dict,
|
||||
data_convertor: PyTorchDataConvertor,
|
||||
model_meta_data: Dict[str, Any] = {},
|
||||
window_size: int = 1,
|
||||
@@ -56,8 +56,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
||||
self.data_convertor = data_convertor
|
||||
self.window_size: int = window_size
|
||||
if init_model:
|
||||
self.load_from_checkpoint(init_model)
|
||||
# if init_model:
|
||||
# self.load_from_checkpoint(init_model)
|
||||
|
||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user