fix bug in continual_learning for PyTorch* models

Author: robcaulk
Date:   2023-05-13 11:14:16 +00:00
commit 3ae3cc63df (parent ad2080ab3e)
4 changed files with 38 additions and 35 deletions


@@ -74,16 +74,17 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
         criterion = torch.nn.CrossEntropyLoss()
-        init_model = self.get_init_model(dk.pair)
-        trainer = PyTorchModelTrainer(
-            model=model,
-            optimizer=optimizer,
-            criterion=criterion,
-            model_meta_data={"class_names": class_names},
-            device=self.device,
-            init_model=init_model,
-            data_convertor=self.data_convertor,
-            **self.trainer_kwargs,
-        )
+        # check if continual_learning is activated, and retrieve the trainer to continue training
+        trainer = self.get_init_model(dk.pair)
+        if trainer is None:
+            trainer = PyTorchModelTrainer(
+                model=model,
+                optimizer=optimizer,
+                criterion=criterion,
+                model_meta_data={"class_names": class_names},
+                device=self.device,
+                data_convertor=self.data_convertor,
+                **self.trainer_kwargs,
+            )
         trainer.fit(data_dictionary, self.splits)
         return trainer
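Note: the fix relies on get_init_model() yielding the trainer cached from the previous retrain cycle when continual_learning is enabled, and None otherwise. A minimal sketch of that retrieval logic, assuming a dict-backed per-pair cache (the model_dictionary name mirrors freqtrade's data-drawer concept but is illustrative, not the exact implementation):

from typing import Any, Dict, Optional

class ContinualLearningCacheSketch:
    """Illustrative stand-in for the base model's trainer caching."""

    def __init__(self, continual_learning: bool):
        self.continual_learning = continual_learning
        self.model_dictionary: Dict[str, Any] = {}  # pair -> fitted trainer

    def get_init_model(self, pair: str) -> Optional[Any]:
        # Return the trainer fitted on the previous retrain cycle, if any;
        # with continual_learning disabled, always start from scratch.
        if self.continual_learning:
            return self.model_dictionary.get(pair)
        return None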


@@ -69,15 +69,16 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
         criterion = torch.nn.MSELoss()
-        init_model = self.get_init_model(dk.pair)
-        trainer = PyTorchModelTrainer(
-            model=model,
-            optimizer=optimizer,
-            criterion=criterion,
-            device=self.device,
-            init_model=init_model,
-            data_convertor=self.data_convertor,
-            **self.trainer_kwargs,
-        )
+        # check if continual_learning is activated, and retrieve the trainer to continue training
+        trainer = self.get_init_model(dk.pair)
+        if trainer is None:
+            trainer = PyTorchModelTrainer(
+                model=model,
+                optimizer=optimizer,
+                criterion=criterion,
+                device=self.device,
+                data_convertor=self.data_convertor,
+                **self.trainer_kwargs,
+            )
         trainer.fit(data_dictionary, self.splits)
         return trainer
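Reusing the whole trainer matters beyond the model weights: AdamW keeps per-parameter moment estimates on the optimizer object, and the old code discarded them every cycle by passing the cached state as init_model while still constructing a fresh optimizer. A small self-contained illustration of that state loss (plain PyTorch, not freqtrade code):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# one training step accumulates AdamW moment estimates
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

print(len(optimizer.state))  # > 0: warm per-parameter state accumulated
fresh = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(len(fresh.state))      # 0: rebuilding the optimizer drops that state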


@@ -75,17 +75,18 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
         criterion = torch.nn.MSELoss()
-        init_model = self.get_init_model(dk.pair)
-        trainer = PyTorchTransformerTrainer(
-            model=model,
-            optimizer=optimizer,
-            criterion=criterion,
-            device=self.device,
-            init_model=init_model,
-            data_convertor=self.data_convertor,
-            window_size=self.window_size,
-            **self.trainer_kwargs,
-        )
+        # check if continual_learning is activated, and retrieve the trainer to continue training
+        trainer = self.get_init_model(dk.pair)
+        if trainer is None:
+            trainer = PyTorchTransformerTrainer(
+                model=model,
+                optimizer=optimizer,
+                criterion=criterion,
+                device=self.device,
+                data_convertor=self.data_convertor,
+                window_size=self.window_size,
+                **self.trainer_kwargs,
+            )
         trainer.fit(data_dictionary, self.splits)
         return trainer
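The same pattern is applied in all three prediction models; the transformer variant only adds window_size when constructing its trainer. Whether get_init_model() returns a cached trainer at all is governed by the continual_learning option in the freqai section of the user's configuration; with it disabled (the default), get_init_model() returns None and a fresh trainer is built on every retrain cycle.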


@@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         optimizer: Optimizer,
         criterion: nn.Module,
         device: str,
-        init_model: Dict,
+        # init_model: Dict,
         data_convertor: PyTorchDataConvertor,
         model_meta_data: Dict[str, Any] = {},
         window_size: int = 1,
@@ -56,8 +56,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
         self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
         self.data_convertor = data_convertor
         self.window_size: int = window_size
-        if init_model:
-            self.load_from_checkpoint(init_model)
+        # if init_model:
+        #     self.load_from_checkpoint(init_model)

     def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
         """