mirror of
https://github.com/freqtrade/freqtrade.git
synced 2026-02-24 21:30:51 +00:00
pytorch - trainer - clean code
This commit is contained in:
@@ -86,7 +86,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
)
|
||||
|
||||
batch_counter = 0
|
||||
for epoch in range(n_epochs):
|
||||
for _ in range(n_epochs):
|
||||
for _, batch_data in enumerate(data_loaders_dictionary["train"]):
|
||||
xb, yb = batch_data
|
||||
xb = xb.to(self.device)
|
||||
@@ -171,7 +171,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
||||
"""
|
||||
- Saving any nn.Module state_dict
|
||||
- Saving model_meta_data, this dict should contain any additional data that the
|
||||
user needs to store. e.g class_names for classification models.
|
||||
user needs to store. e.g. class_names for classification models.
|
||||
"""
|
||||
|
||||
torch.save({
|
||||
|
||||
Reference in New Issue
Block a user