mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
pytorch - trainer - reomve max_n_eval_batches arg from estimate loss method
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
@@ -53,7 +52,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.max_iters: int = kwargs.get("max_iters", 100)
|
self.max_iters: int = kwargs.get("max_iters", 100)
|
||||||
self.batch_size: int = kwargs.get("batch_size", 64)
|
self.batch_size: int = kwargs.get("batch_size", 64)
|
||||||
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
|
self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None) # TODO change this to n_batches
|
||||||
self.data_convertor = data_convertor
|
self.data_convertor = data_convertor
|
||||||
self.window_size: int = window_size
|
self.window_size: int = window_size
|
||||||
self.tb_logger = tb_logger
|
self.tb_logger = tb_logger
|
||||||
@@ -95,25 +94,16 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
if "test" in splits:
|
if "test" in splits:
|
||||||
self.estimate_loss(
|
self.estimate_loss(data_loaders_dictionary, "test")
|
||||||
data_loaders_dictionary,
|
|
||||||
self.max_n_eval_batches,
|
|
||||||
"test"
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def estimate_loss(
|
def estimate_loss(
|
||||||
self,
|
self,
|
||||||
data_loader_dictionary: Dict[str, DataLoader],
|
data_loader_dictionary: Dict[str, DataLoader],
|
||||||
max_n_eval_batches: Optional[int],
|
|
||||||
split: str,
|
split: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
n_batches = 0
|
|
||||||
for i, batch_data in enumerate(data_loader_dictionary[split]):
|
for i, batch_data in enumerate(data_loader_dictionary[split]):
|
||||||
if max_n_eval_batches and i > max_n_eval_batches:
|
|
||||||
n_batches += 1
|
|
||||||
break
|
|
||||||
xb, yb = batch_data
|
xb, yb = batch_data
|
||||||
xb.to(self.device)
|
xb.to(self.device)
|
||||||
yb.to(self.device)
|
yb.to(self.device)
|
||||||
@@ -158,8 +148,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
across different n_obs - the number of data points.
|
across different n_obs - the number of data points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_batches = math.ceil(n_obs // batch_size)
|
n_batches = n_obs // batch_size
|
||||||
epochs = math.ceil(n_iters // n_batches)
|
epochs = n_iters // n_batches
|
||||||
if epochs <= 10:
|
if epochs <= 10:
|
||||||
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
logger.warning("User set `max_iters` in such a way that the trainer will only perform "
|
||||||
f" {epochs} epochs. Please consider increasing this value accordingly")
|
f" {epochs} epochs. Please consider increasing this value accordingly")
|
||||||
|
|||||||
Reference in New Issue
Block a user