mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-12-18 22:01:15 +00:00
chore: update freqai to modern typing syntax
This commit is contained in:
@@ -2,7 +2,7 @@ import logging
|
|||||||
import random
|
import random
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Type, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -89,7 +89,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
self.fee = fee
|
self.fee = fee
|
||||||
|
|
||||||
# set here to default 5Ac, but all children envs can override this
|
# set here to default 5Ac, but all children envs can override this
|
||||||
self.actions: Type[Enum] = BaseActions
|
self.actions: type[Enum] = BaseActions
|
||||||
self.tensorboard_metrics: dict = {}
|
self.tensorboard_metrics: dict = {}
|
||||||
self.can_short: bool = can_short
|
self.can_short: bool = can_short
|
||||||
self.live: bool = live
|
self.live: bool = live
|
||||||
@@ -163,7 +163,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
Unique to the environment action count. Must be inherited.
|
Unique to the environment action count. Must be inherited.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def action_masks(self) -> List[bool]:
|
def action_masks(self) -> list[bool]:
|
||||||
return [self._is_valid(action.value) for action in self.actions]
|
return [self._is_valid(action.value) for action in self.actions]
|
||||||
|
|
||||||
def seed(self, seed: int = 1):
|
def seed(self, seed: int = 1):
|
||||||
@@ -375,7 +375,7 @@ class BaseEnvironment(gym.Env):
|
|||||||
def current_price(self) -> float:
|
def current_price(self) -> float:
|
||||||
return self.prices.iloc[self._current_tick].open
|
return self.prices.iloc[self._current_tick].open
|
||||||
|
|
||||||
def get_actions(self) -> Type[Enum]:
|
def get_actions(self) -> type[Enum]:
|
||||||
"""
|
"""
|
||||||
Used by SubprocVecEnv to get actions from
|
Used by SubprocVecEnv to get actions from
|
||||||
initialized env for tensorboard callback
|
initialized env for tensorboard callback
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import logging
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -114,7 +114,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
training_filter=True,
|
training_filter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
dd: Dict[str, Any] = dk.make_train_test_datasets(features_filtered, labels_filtered)
|
dd: dict[str, Any] = dk.make_train_test_datasets(features_filtered, labels_filtered)
|
||||||
self.df_raw = copy.deepcopy(dd["train_features"])
|
self.df_raw = copy.deepcopy(dd["train_features"])
|
||||||
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
dk.fit_labels() # FIXME useless for now, but just satiating append methods
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
def set_train_and_eval_environments(
|
def set_train_and_eval_environments(
|
||||||
self,
|
self,
|
||||||
data_dictionary: Dict[str, DataFrame],
|
data_dictionary: dict[str, DataFrame],
|
||||||
prices_train: DataFrame,
|
prices_train: DataFrame,
|
||||||
prices_test: DataFrame,
|
prices_test: DataFrame,
|
||||||
dk: FreqaiDataKitchen,
|
dk: FreqaiDataKitchen,
|
||||||
@@ -183,7 +183,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
actions = self.train_env.get_actions()
|
actions = self.train_env.get_actions()
|
||||||
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
self.tensorboard_callback = TensorboardCallback(verbose=1, actions=actions)
|
||||||
|
|
||||||
def pack_env_dict(self, pair: str) -> Dict[str, Any]:
|
def pack_env_dict(self, pair: str) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create dictionary of environment arguments
|
Create dictionary of environment arguments
|
||||||
"""
|
"""
|
||||||
@@ -204,7 +204,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
return env_info
|
return env_info
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
def fit(self, data_dictionary: dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
"""
|
"""
|
||||||
Agent customizations and abstract Reinforcement Learning customizations
|
Agent customizations and abstract Reinforcement Learning customizations
|
||||||
go in here. Abstract method, so this function must be overridden by
|
go in here. Abstract method, so this function must be overridden by
|
||||||
@@ -212,7 +212,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_state_info(self, pair: str) -> Tuple[float, float, int]:
|
def get_state_info(self, pair: str) -> tuple[float, float, int]:
|
||||||
"""
|
"""
|
||||||
State info during dry/live (not backtesting) which is fed back
|
State info during dry/live (not backtesting) which is fed back
|
||||||
into the model.
|
into the model.
|
||||||
@@ -250,7 +250,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_dataframe: Full dataframe for the current backtest period.
|
:param unfiltered_dataframe: Full dataframe for the current backtest period.
|
||||||
@@ -303,7 +303,7 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
def build_ohlc_price_dataframes(
|
def build_ohlc_price_dataframes(
|
||||||
self, data_dictionary: dict, pair: str, dk: FreqaiDataKitchen
|
self, data_dictionary: dict, pair: str, dk: FreqaiDataKitchen
|
||||||
) -> Tuple[DataFrame, DataFrame]:
|
) -> tuple[DataFrame, DataFrame]:
|
||||||
"""
|
"""
|
||||||
Builds the train prices and test prices for the environment.
|
Builds the train prices and test prices for the environment.
|
||||||
"""
|
"""
|
||||||
@@ -482,13 +482,13 @@ class BaseReinforcementLearningModel(IFreqaiModel):
|
|||||||
|
|
||||||
|
|
||||||
def make_env(
|
def make_env(
|
||||||
MyRLEnv: Type[BaseEnvironment],
|
MyRLEnv: type[BaseEnvironment],
|
||||||
env_id: str,
|
env_id: str,
|
||||||
rank: int,
|
rank: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
train_df: DataFrame,
|
train_df: DataFrame,
|
||||||
price: DataFrame,
|
price: DataFrame,
|
||||||
env_info: Dict[str, Any] = {},
|
env_info: dict[str, Any] = {},
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""
|
"""
|
||||||
Utility function for multiprocessed env.
|
Utility function for multiprocessed env.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -86,7 +86,7 @@ class BaseClassifierModel(IFreqaiModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -44,7 +44,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param dk: dk: The datakitchen object
|
:param dk: dk: The datakitchen object
|
||||||
@@ -100,9 +100,9 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
|||||||
|
|
||||||
def encode_class_names(
|
def encode_class_names(
|
||||||
self,
|
self,
|
||||||
data_dictionary: Dict[str, pd.DataFrame],
|
data_dictionary: dict[str, pd.DataFrame],
|
||||||
dk: FreqaiDataKitchen,
|
dk: FreqaiDataKitchen,
|
||||||
class_names: List[str],
|
class_names: list[str],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
encode class name, str -> int
|
encode class name, str -> int
|
||||||
@@ -119,7 +119,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def assert_valid_class_names(target_column: pd.Series, class_names: List[str]):
|
def assert_valid_class_names(target_column: pd.Series, class_names: list[str]):
|
||||||
non_defined_labels = set(target_column) - set(class_names)
|
non_defined_labels = set(target_column) - set(class_names)
|
||||||
if len(non_defined_labels) != 0:
|
if len(non_defined_labels) != 0:
|
||||||
raise OperationalException(
|
raise OperationalException(
|
||||||
@@ -127,7 +127,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
|||||||
f"expecting labels: {class_names}",
|
f"expecting labels: {class_names}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode_class_names(self, class_ints: torch.Tensor) -> List[str]:
|
def decode_class_names(self, class_ints: torch.Tensor) -> list[str]:
|
||||||
"""
|
"""
|
||||||
decode class name, int -> str
|
decode class name, int -> str
|
||||||
"""
|
"""
|
||||||
@@ -141,14 +141,14 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
|||||||
|
|
||||||
def convert_label_column_to_int(
|
def convert_label_column_to_int(
|
||||||
self,
|
self,
|
||||||
data_dictionary: Dict[str, pd.DataFrame],
|
data_dictionary: dict[str, pd.DataFrame],
|
||||||
dk: FreqaiDataKitchen,
|
dk: FreqaiDataKitchen,
|
||||||
class_names: List[str],
|
class_names: list[str],
|
||||||
):
|
):
|
||||||
self.init_class_names_to_index_mapping(class_names)
|
self.init_class_names_to_index_mapping(class_names)
|
||||||
self.encode_class_names(data_dictionary, dk, class_names)
|
self.encode_class_names(data_dictionary, dk, class_names)
|
||||||
|
|
||||||
def get_class_names(self) -> List[str]:
|
def get_class_names(self) -> list[str]:
|
||||||
if not self.class_names:
|
if not self.class_names:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"self.class_names is empty, "
|
"self.class_names is empty, "
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -24,7 +24,7 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -88,7 +88,7 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import threading
|
|||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Tuple, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -69,14 +69,14 @@ class FreqaiDataDrawer:
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.freqai_info = config.get("freqai", {})
|
self.freqai_info = config.get("freqai", {})
|
||||||
# dictionary holding all pair metadata necessary to load in from disk
|
# dictionary holding all pair metadata necessary to load in from disk
|
||||||
self.pair_dict: Dict[str, pair_info] = {}
|
self.pair_dict: dict[str, pair_info] = {}
|
||||||
# dictionary holding all actively inferenced models in memory given a model filename
|
# dictionary holding all actively inferenced models in memory given a model filename
|
||||||
self.model_dictionary: Dict[str, Any] = {}
|
self.model_dictionary: dict[str, Any] = {}
|
||||||
# all additional metadata that we want to keep in ram
|
# all additional metadata that we want to keep in ram
|
||||||
self.meta_data_dictionary: Dict[str, Dict[str, Any]] = {}
|
self.meta_data_dictionary: dict[str, dict[str, Any]] = {}
|
||||||
self.model_return_values: Dict[str, DataFrame] = {}
|
self.model_return_values: dict[str, DataFrame] = {}
|
||||||
self.historic_data: Dict[str, Dict[str, DataFrame]] = {}
|
self.historic_data: dict[str, dict[str, DataFrame]] = {}
|
||||||
self.historic_predictions: Dict[str, DataFrame] = {}
|
self.historic_predictions: dict[str, DataFrame] = {}
|
||||||
self.full_path = full_path
|
self.full_path = full_path
|
||||||
self.historic_predictions_path = Path(self.full_path / "historic_predictions.pkl")
|
self.historic_predictions_path = Path(self.full_path / "historic_predictions.pkl")
|
||||||
self.historic_predictions_bkp_path = Path(
|
self.historic_predictions_bkp_path = Path(
|
||||||
@@ -87,14 +87,14 @@ class FreqaiDataDrawer:
|
|||||||
self.metric_tracker_path = Path(self.full_path / "metric_tracker.json")
|
self.metric_tracker_path = Path(self.full_path / "metric_tracker.json")
|
||||||
self.load_drawer_from_disk()
|
self.load_drawer_from_disk()
|
||||||
self.load_historic_predictions_from_disk()
|
self.load_historic_predictions_from_disk()
|
||||||
self.metric_tracker: Dict[str, Dict[str, Dict[str, list]]] = {}
|
self.metric_tracker: dict[str, dict[str, dict[str, list]]] = {}
|
||||||
self.load_metric_tracker_from_disk()
|
self.load_metric_tracker_from_disk()
|
||||||
self.training_queue: Dict[str, int] = {}
|
self.training_queue: dict[str, int] = {}
|
||||||
self.history_lock = threading.Lock()
|
self.history_lock = threading.Lock()
|
||||||
self.save_lock = threading.Lock()
|
self.save_lock = threading.Lock()
|
||||||
self.pair_dict_lock = threading.Lock()
|
self.pair_dict_lock = threading.Lock()
|
||||||
self.metric_tracker_lock = threading.Lock()
|
self.metric_tracker_lock = threading.Lock()
|
||||||
self.old_DBSCAN_eps: Dict[str, float] = {}
|
self.old_DBSCAN_eps: dict[str, float] = {}
|
||||||
self.empty_pair_dict: pair_info = {
|
self.empty_pair_dict: pair_info = {
|
||||||
"model_filename": "",
|
"model_filename": "",
|
||||||
"trained_timestamp": 0,
|
"trained_timestamp": 0,
|
||||||
@@ -228,7 +228,7 @@ class FreqaiDataDrawer:
|
|||||||
self.pair_dict, fp, default=self.np_encoder, number_mode=rapidjson.NM_NATIVE
|
self.pair_dict, fp, default=self.np_encoder, number_mode=rapidjson.NM_NATIVE
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_global_metadata_to_disk(self, metadata: Dict[str, Any]):
|
def save_global_metadata_to_disk(self, metadata: dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Save global metadata json to disk
|
Save global metadata json to disk
|
||||||
"""
|
"""
|
||||||
@@ -242,7 +242,7 @@ class FreqaiDataDrawer:
|
|||||||
if isinstance(obj, np.generic):
|
if isinstance(obj, np.generic):
|
||||||
return obj.item()
|
return obj.item()
|
||||||
|
|
||||||
def get_pair_dict_info(self, pair: str) -> Tuple[str, int]:
|
def get_pair_dict_info(self, pair: str) -> tuple[str, int]:
|
||||||
"""
|
"""
|
||||||
Locate and load existing model metadata from persistent storage. If not located,
|
Locate and load existing model metadata from persistent storage. If not located,
|
||||||
create a new one and append the current pair to it and prepare it for its first
|
create a new one and append the current pair to it and prepare it for its first
|
||||||
@@ -446,7 +446,7 @@ class FreqaiDataDrawer:
|
|||||||
|
|
||||||
pattern = re.compile(r"sub-train-(\w+)_(\d{10})")
|
pattern = re.compile(r"sub-train-(\w+)_(\d{10})")
|
||||||
|
|
||||||
delete_dict: Dict[str, Any] = {}
|
delete_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
for directory in model_folders:
|
for directory in model_folders:
|
||||||
result = pattern.match(str(directory.name))
|
result = pattern.match(str(directory.name))
|
||||||
@@ -704,7 +704,7 @@ class FreqaiDataDrawer:
|
|||||||
|
|
||||||
def get_base_and_corr_dataframes(
|
def get_base_and_corr_dataframes(
|
||||||
self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen
|
self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen
|
||||||
) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
|
) -> tuple[dict[Any, Any], dict[Any, Any]]:
|
||||||
"""
|
"""
|
||||||
Searches through our historic_data in memory and returns the dataframes relevant
|
Searches through our historic_data in memory and returns the dataframes relevant
|
||||||
to the present pair.
|
to the present pair.
|
||||||
@@ -713,8 +713,8 @@ class FreqaiDataDrawer:
|
|||||||
:param metadata: dict = strategy furnished pair metadata
|
:param metadata: dict = strategy furnished pair metadata
|
||||||
"""
|
"""
|
||||||
with self.history_lock:
|
with self.history_lock:
|
||||||
corr_dataframes: Dict[Any, Any] = {}
|
corr_dataframes: dict[Any, Any] = {}
|
||||||
base_dataframes: Dict[Any, Any] = {}
|
base_dataframes: dict[Any, Any] = {}
|
||||||
historic_data = self.historic_data
|
historic_data = self.historic_data
|
||||||
pairs = self.freqai_info["feature_parameters"].get("include_corr_pairlist", [])
|
pairs = self.freqai_info["feature_parameters"].get("include_corr_pairlist", [])
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import random
|
|||||||
import shutil
|
import shutil
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -64,15 +64,15 @@ class FreqaiDataKitchen:
|
|||||||
live: bool = False,
|
live: bool = False,
|
||||||
pair: str = "",
|
pair: str = "",
|
||||||
):
|
):
|
||||||
self.data: Dict[str, Any] = {}
|
self.data: dict[str, Any] = {}
|
||||||
self.data_dictionary: Dict[str, DataFrame] = {}
|
self.data_dictionary: dict[str, DataFrame] = {}
|
||||||
self.config = config
|
self.config = config
|
||||||
self.freqai_config: Dict[str, Any] = config["freqai"]
|
self.freqai_config: dict[str, Any] = config["freqai"]
|
||||||
self.full_df: DataFrame = DataFrame()
|
self.full_df: DataFrame = DataFrame()
|
||||||
self.append_df: DataFrame = DataFrame()
|
self.append_df: DataFrame = DataFrame()
|
||||||
self.data_path = Path()
|
self.data_path = Path()
|
||||||
self.label_list: List = []
|
self.label_list: list = []
|
||||||
self.training_features_list: List = []
|
self.training_features_list: list = []
|
||||||
self.model_filename: str = ""
|
self.model_filename: str = ""
|
||||||
self.backtesting_results_path = Path()
|
self.backtesting_results_path = Path()
|
||||||
self.backtest_predictions_folder: str = "backtesting_predictions"
|
self.backtest_predictions_folder: str = "backtesting_predictions"
|
||||||
@@ -104,9 +104,9 @@ class FreqaiDataKitchen:
|
|||||||
else:
|
else:
|
||||||
self.thread_count = self.freqai_config["data_kitchen_thread_count"]
|
self.thread_count = self.freqai_config["data_kitchen_thread_count"]
|
||||||
self.train_dates: DataFrame = pd.DataFrame()
|
self.train_dates: DataFrame = pd.DataFrame()
|
||||||
self.unique_classes: Dict[str, list] = {}
|
self.unique_classes: dict[str, list] = {}
|
||||||
self.unique_class_list: list = []
|
self.unique_class_list: list = []
|
||||||
self.backtest_live_models_data: Dict[str, Any] = {}
|
self.backtest_live_models_data: dict[str, Any] = {}
|
||||||
|
|
||||||
def set_paths(
|
def set_paths(
|
||||||
self,
|
self,
|
||||||
@@ -127,7 +127,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
def make_train_test_datasets(
|
def make_train_test_datasets(
|
||||||
self, filtered_dataframe: DataFrame, labels: DataFrame
|
self, filtered_dataframe: DataFrame, labels: DataFrame
|
||||||
) -> Dict[Any, Any]:
|
) -> dict[Any, Any]:
|
||||||
"""
|
"""
|
||||||
Given the dataframe for the full history for training, split the data into
|
Given the dataframe for the full history for training, split the data into
|
||||||
training and test data according to user specified parameters in configuration
|
training and test data according to user specified parameters in configuration
|
||||||
@@ -213,10 +213,10 @@ class FreqaiDataKitchen:
|
|||||||
def filter_features(
|
def filter_features(
|
||||||
self,
|
self,
|
||||||
unfiltered_df: DataFrame,
|
unfiltered_df: DataFrame,
|
||||||
training_feature_list: List,
|
training_feature_list: list,
|
||||||
label_list: List = list(),
|
label_list: list = list(),
|
||||||
training_filter: bool = True,
|
training_filter: bool = True,
|
||||||
) -> Tuple[DataFrame, DataFrame]:
|
) -> tuple[DataFrame, DataFrame]:
|
||||||
"""
|
"""
|
||||||
Filter the unfiltered dataframe to extract the user requested features/labels and properly
|
Filter the unfiltered dataframe to extract the user requested features/labels and properly
|
||||||
remove all NaNs. Any row with a NaN is removed from training dataset or replaced with
|
remove all NaNs. Any row with a NaN is removed from training dataset or replaced with
|
||||||
@@ -306,7 +306,7 @@ class FreqaiDataKitchen:
|
|||||||
test_labels: DataFrame,
|
test_labels: DataFrame,
|
||||||
train_weights: Any,
|
train_weights: Any,
|
||||||
test_weights: Any,
|
test_weights: Any,
|
||||||
) -> Dict:
|
) -> dict:
|
||||||
self.data_dictionary = {
|
self.data_dictionary = {
|
||||||
"train_features": train_df,
|
"train_features": train_df,
|
||||||
"test_features": test_df,
|
"test_features": test_df,
|
||||||
@@ -321,7 +321,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
def split_timerange(
|
def split_timerange(
|
||||||
self, tr: str, train_split: int = 28, bt_split: float = 7
|
self, tr: str, train_split: int = 28, bt_split: float = 7
|
||||||
) -> Tuple[list, list]:
|
) -> tuple[list, list]:
|
||||||
"""
|
"""
|
||||||
Function which takes a single time range (tr) and splits it
|
Function which takes a single time range (tr) and splits it
|
||||||
into sub timeranges to train and backtest on based on user input
|
into sub timeranges to train and backtest on based on user input
|
||||||
@@ -535,7 +535,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
def check_if_new_training_required(
|
def check_if_new_training_required(
|
||||||
self, trained_timestamp: int
|
self, trained_timestamp: int
|
||||||
) -> Tuple[bool, TimeRange, TimeRange]:
|
) -> tuple[bool, TimeRange, TimeRange]:
|
||||||
time = datetime.now(tz=timezone.utc).timestamp()
|
time = datetime.now(tz=timezone.utc).timestamp()
|
||||||
trained_timerange = TimeRange()
|
trained_timerange = TimeRange()
|
||||||
data_load_timerange = TimeRange()
|
data_load_timerange = TimeRange()
|
||||||
@@ -603,7 +603,7 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
def extract_corr_pair_columns_from_populated_indicators(
|
def extract_corr_pair_columns_from_populated_indicators(
|
||||||
self, dataframe: DataFrame
|
self, dataframe: DataFrame
|
||||||
) -> Dict[str, DataFrame]:
|
) -> dict[str, DataFrame]:
|
||||||
"""
|
"""
|
||||||
Find the columns of the dataframe corresponding to the corr_pairlist, save them
|
Find the columns of the dataframe corresponding to the corr_pairlist, save them
|
||||||
in a dictionary to be reused and attached to other pairs.
|
in a dictionary to be reused and attached to other pairs.
|
||||||
@@ -612,7 +612,7 @@ class FreqaiDataKitchen:
|
|||||||
:return: corr_dataframes, dictionary of dataframes to be attached
|
:return: corr_dataframes, dictionary of dataframes to be attached
|
||||||
to other pairs in same candle.
|
to other pairs in same candle.
|
||||||
"""
|
"""
|
||||||
corr_dataframes: Dict[str, DataFrame] = {}
|
corr_dataframes: dict[str, DataFrame] = {}
|
||||||
pairs = self.freqai_config["feature_parameters"].get("include_corr_pairlist", [])
|
pairs = self.freqai_config["feature_parameters"].get("include_corr_pairlist", [])
|
||||||
|
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
@@ -628,7 +628,7 @@ class FreqaiDataKitchen:
|
|||||||
return corr_dataframes
|
return corr_dataframes
|
||||||
|
|
||||||
def attach_corr_pair_columns(
|
def attach_corr_pair_columns(
|
||||||
self, dataframe: DataFrame, corr_dataframes: Dict[str, DataFrame], current_pair: str
|
self, dataframe: DataFrame, corr_dataframes: dict[str, DataFrame], current_pair: str
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
"""
|
"""
|
||||||
Attach the existing corr_pair dataframes to the current pair dataframe before training
|
Attach the existing corr_pair dataframes to the current pair dataframe before training
|
||||||
@@ -731,7 +731,7 @@ class FreqaiDataKitchen:
|
|||||||
:param is_corr_pairs: bool = whether the pair is a corr pair or not
|
:param is_corr_pairs: bool = whether the pair is a corr pair or not
|
||||||
:return: dataframe = populated dataframe
|
:return: dataframe = populated dataframe
|
||||||
"""
|
"""
|
||||||
tfs: List[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
tfs: list[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
||||||
|
|
||||||
for tf in tfs:
|
for tf in tfs:
|
||||||
metadata = {"pair": pair, "tf": tf}
|
metadata = {"pair": pair, "tf": tf}
|
||||||
@@ -810,8 +810,8 @@ class FreqaiDataKitchen:
|
|||||||
f"{DOCS_LINK}/freqai-feature-engineering/"
|
f"{DOCS_LINK}/freqai-feature-engineering/"
|
||||||
)
|
)
|
||||||
|
|
||||||
tfs: List[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
tfs: list[str] = self.freqai_config["feature_parameters"].get("include_timeframes")
|
||||||
pairs: List[str] = self.freqai_config["feature_parameters"].get("include_corr_pairlist", [])
|
pairs: list[str] = self.freqai_config["feature_parameters"].get("include_corr_pairlist", [])
|
||||||
|
|
||||||
for tf in tfs:
|
for tf in tfs:
|
||||||
if tf not in base_dataframes:
|
if tf not in base_dataframes:
|
||||||
@@ -828,7 +828,7 @@ class FreqaiDataKitchen:
|
|||||||
else:
|
else:
|
||||||
dataframe = base_dataframes[self.config["timeframe"]].copy()
|
dataframe = base_dataframes[self.config["timeframe"]].copy()
|
||||||
|
|
||||||
corr_pairs: List[str] = self.freqai_config["feature_parameters"].get(
|
corr_pairs: list[str] = self.freqai_config["feature_parameters"].get(
|
||||||
"include_corr_pairlist", []
|
"include_corr_pairlist", []
|
||||||
)
|
)
|
||||||
dataframe = self.populate_features(
|
dataframe = self.populate_features(
|
||||||
@@ -953,7 +953,7 @@ class FreqaiDataKitchen:
|
|||||||
Returns default FreqAI model path
|
Returns default FreqAI model path
|
||||||
:param config: Configuration dictionary
|
:param config: Configuration dictionary
|
||||||
"""
|
"""
|
||||||
freqai_config: Dict[str, Any] = config["freqai"]
|
freqai_config: dict[str, Any] = config["freqai"]
|
||||||
return Path(config["user_data_dir"] / "models" / str(freqai_config.get("identifier")))
|
return Path(config["user_data_dir"] / "models" / str(freqai_config.get("identifier")))
|
||||||
|
|
||||||
def remove_special_chars_from_feature_names(self, dataframe: pd.DataFrame) -> pd.DataFrame:
|
def remove_special_chars_from_feature_names(self, dataframe: pd.DataFrame) -> pd.DataFrame:
|
||||||
@@ -992,7 +992,7 @@ class FreqaiDataKitchen:
|
|||||||
return timerange
|
return timerange
|
||||||
|
|
||||||
# deprecated functions
|
# deprecated functions
|
||||||
def normalize_data(self, data_dictionary: Dict) -> Dict[Any, Any]:
|
def normalize_data(self, data_dictionary: dict) -> dict[Any, Any]:
|
||||||
"""
|
"""
|
||||||
Deprecation warning, migration assistance
|
Deprecation warning, migration assistance
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import datasieve.transforms as ds
|
import datasieve.transforms as ds
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -59,11 +59,11 @@ class IFreqaiModel(ABC):
|
|||||||
def __init__(self, config: Config) -> None:
|
def __init__(self, config: Config) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.assert_config(self.config)
|
self.assert_config(self.config)
|
||||||
self.freqai_info: Dict[str, Any] = config["freqai"]
|
self.freqai_info: dict[str, Any] = config["freqai"]
|
||||||
self.data_split_parameters: Dict[str, Any] = config.get("freqai", {}).get(
|
self.data_split_parameters: dict[str, Any] = config.get("freqai", {}).get(
|
||||||
"data_split_parameters", {}
|
"data_split_parameters", {}
|
||||||
)
|
)
|
||||||
self.model_training_parameters: Dict[str, Any] = config.get("freqai", {}).get(
|
self.model_training_parameters: dict[str, Any] = config.get("freqai", {}).get(
|
||||||
"model_training_parameters", {}
|
"model_training_parameters", {}
|
||||||
)
|
)
|
||||||
self.identifier: str = self.freqai_info.get("identifier", "no_id_provided")
|
self.identifier: str = self.freqai_info.get("identifier", "no_id_provided")
|
||||||
@@ -80,14 +80,14 @@ class IFreqaiModel(ABC):
|
|||||||
self.dd.current_candle = self.current_candle
|
self.dd.current_candle = self.current_candle
|
||||||
self.scanning = False
|
self.scanning = False
|
||||||
self.ft_params = self.freqai_info["feature_parameters"]
|
self.ft_params = self.freqai_info["feature_parameters"]
|
||||||
self.corr_pairlist: List[str] = self.ft_params.get("include_corr_pairlist", [])
|
self.corr_pairlist: list[str] = self.ft_params.get("include_corr_pairlist", [])
|
||||||
self.keras: bool = self.freqai_info.get("keras", False)
|
self.keras: bool = self.freqai_info.get("keras", False)
|
||||||
if self.keras and self.ft_params.get("DI_threshold", 0):
|
if self.keras and self.ft_params.get("DI_threshold", 0):
|
||||||
self.ft_params["DI_threshold"] = 0
|
self.ft_params["DI_threshold"] = 0
|
||||||
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
|
logger.warning("DI threshold is not configured for Keras models yet. Deactivating.")
|
||||||
|
|
||||||
self.CONV_WIDTH = self.freqai_info.get("conv_width", 1)
|
self.CONV_WIDTH = self.freqai_info.get("conv_width", 1)
|
||||||
self.class_names: List[str] = [] # used in classification subclasses
|
self.class_names: list[str] = [] # used in classification subclasses
|
||||||
self.pair_it = 0
|
self.pair_it = 0
|
||||||
self.pair_it_train = 0
|
self.pair_it_train = 0
|
||||||
self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist"))
|
self.total_pairs = len(self.config.get("exchange", {}).get("pair_whitelist"))
|
||||||
@@ -99,13 +99,13 @@ class IFreqaiModel(ABC):
|
|||||||
self.base_tf_seconds = timeframe_to_seconds(self.config["timeframe"])
|
self.base_tf_seconds = timeframe_to_seconds(self.config["timeframe"])
|
||||||
self.continual_learning = self.freqai_info.get("continual_learning", False)
|
self.continual_learning = self.freqai_info.get("continual_learning", False)
|
||||||
self.plot_features = self.ft_params.get("plot_feature_importances", 0)
|
self.plot_features = self.ft_params.get("plot_feature_importances", 0)
|
||||||
self.corr_dataframes: Dict[str, DataFrame] = {}
|
self.corr_dataframes: dict[str, DataFrame] = {}
|
||||||
# get_corr_dataframes is controlling the caching of corr_dataframes
|
# get_corr_dataframes is controlling the caching of corr_dataframes
|
||||||
# for improved performance. Careful with this boolean.
|
# for improved performance. Careful with this boolean.
|
||||||
self.get_corr_dataframes: bool = True
|
self.get_corr_dataframes: bool = True
|
||||||
self._threads: List[threading.Thread] = []
|
self._threads: list[threading.Thread] = []
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self.metadata: Dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
self.metadata: dict[str, Any] = self.dd.load_global_metadata_from_disk()
|
||||||
self.data_provider: Optional[DataProvider] = None
|
self.data_provider: Optional[DataProvider] = None
|
||||||
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
self.max_system_threads = max(int(psutil.cpu_count() * 2 - 2), 1)
|
||||||
self.can_short = True # overridden in start() with strategy.can_short
|
self.can_short = True # overridden in start() with strategy.can_short
|
||||||
@@ -901,7 +901,7 @@ class IFreqaiModel(ABC):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def update_metadata(self, metadata: Dict[str, Any]):
|
def update_metadata(self, metadata: dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Update global metadata and save the updated json file
|
Update global metadata and save the updated json file
|
||||||
:param metadata: new global metadata dict
|
:param metadata: new global metadata dict
|
||||||
@@ -954,7 +954,7 @@ class IFreqaiModel(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Most regressors use the same function names and arguments e.g. user
|
Most regressors use the same function names and arguments e.g. user
|
||||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||||
@@ -968,7 +968,7 @@ class IFreqaiModel(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, NDArray[np.int_]]:
|
) -> tuple[DataFrame, NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from catboost import CatBoostClassifier, Pool
|
from catboost import CatBoostClassifier, Pool
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class CatboostClassifier(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from catboost import CatBoostClassifier, Pool
|
from catboost import CatBoostClassifier, Pool
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ class CatboostClassifierMultiTarget(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from catboost import CatBoostRegressor, Pool
|
from catboost import CatBoostRegressor, Pool
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class CatboostRegressor(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from catboost import CatBoostRegressor, Pool
|
from catboost import CatBoostRegressor, Pool
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from lightgbm import LGBMClassifier
|
from lightgbm import LGBMClassifier
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class LightGBMClassifier(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from lightgbm import LGBMClassifier
|
from lightgbm import LGBMClassifier
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class LightGBMClassifierMultiTarget(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from lightgbm import LGBMRegressor
|
from lightgbm import LGBMRegressor
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class LightGBMRegressor(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from lightgbm import LGBMRegressor
|
from lightgbm import LGBMRegressor
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -52,10 +52,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
config = self.freqai_info.get("model_training_parameters", {})
|
config = self.freqai_info.get("model_training_parameters", {})
|
||||||
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
||||||
self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {})
|
self.model_kwargs: dict[str, Any] = config.get("model_kwargs", {})
|
||||||
self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {})
|
self.trainer_kwargs: dict[str, Any] = config.get("trainer_kwargs", {})
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -51,10 +51,10 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
config = self.freqai_info.get("model_training_parameters", {})
|
config = self.freqai_info.get("model_training_parameters", {})
|
||||||
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
||||||
self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {})
|
self.model_kwargs: dict[str, Any] = config.get("model_kwargs", {})
|
||||||
self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {})
|
self.trainer_kwargs: dict[str, Any] = config.get("trainer_kwargs", {})
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -60,10 +60,10 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
config = self.freqai_info.get("model_training_parameters", {})
|
config = self.freqai_info.get("model_training_parameters", {})
|
||||||
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
self.learning_rate: float = config.get("learning_rate", 3e-4)
|
||||||
self.model_kwargs: Dict[str, Any] = config.get("model_kwargs", {})
|
self.model_kwargs: dict[str, Any] = config.get("model_kwargs", {})
|
||||||
self.trainer_kwargs: Dict[str, Any] = config.get("trainer_kwargs", {})
|
self.trainer_kwargs: dict[str, Any] = config.get("trainer_kwargs", {})
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
@@ -100,7 +100,7 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: pd.DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: pd.DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[pd.DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[pd.DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common.callbacks import ProgressBarCallback
|
from stable_baselines3.common.callbacks import ProgressBarCallback
|
||||||
@@ -44,7 +44,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
take fine-tuned control over the data handling pipeline.
|
take fine-tuned control over the data handling pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
def fit(self, data_dictionary: dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
|
||||||
"""
|
"""
|
||||||
User customizable fit method
|
User customizable fit method
|
||||||
:param data_dictionary: dict = common data dictionary containing all train/test
|
:param data_dictionary: dict = common data dictionary containing all train/test
|
||||||
@@ -77,7 +77,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
)
|
)
|
||||||
model = self.dd.model_dictionary[dk.pair]
|
model = self.dd.model_dictionary[dk.pair]
|
||||||
model.set_env(self.train_env)
|
model.set_env(self.train_env)
|
||||||
callbacks: List[Any] = [self.eval_callback, self.tensorboard_callback]
|
callbacks: list[Any] = [self.eval_callback, self.tensorboard_callback]
|
||||||
progressbar_callback: Optional[ProgressBarCallback] = None
|
progressbar_callback: Optional[ProgressBarCallback] = None
|
||||||
if self.rl_config.get("progress_bar", False):
|
if self.rl_config.get("progress_bar", False):
|
||||||
progressbar_callback = ProgressBarCallback()
|
progressbar_callback = ProgressBarCallback()
|
||||||
@@ -101,7 +101,7 @@ class ReinforcementLearner(BaseReinforcementLearningModel):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
MyRLEnv: Type[BaseEnvironment]
|
MyRLEnv: type[BaseEnvironment]
|
||||||
|
|
||||||
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
|
class MyRLEnv(Base5ActionRLEnv): # type: ignore[no-redef]
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
|
||||||
@@ -22,7 +22,7 @@ class ReinforcementLearner_multiproc(ReinforcementLearner):
|
|||||||
|
|
||||||
def set_train_and_eval_environments(
|
def set_train_and_eval_environments(
|
||||||
self,
|
self,
|
||||||
data_dictionary: Dict[str, Any],
|
data_dictionary: dict[str, Any],
|
||||||
prices_train: DataFrame,
|
prices_train: DataFrame,
|
||||||
prices_test: DataFrame,
|
prices_test: DataFrame,
|
||||||
dk: FreqaiDataKitchen,
|
dk: FreqaiDataKitchen,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -24,7 +24,7 @@ class SKLearnRandomForestClassifier(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
@@ -61,7 +61,7 @@ class SKLearnRandomForestClassifier(BaseClassifierModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -26,7 +26,7 @@ class XGBoostClassifier(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
@@ -64,7 +64,7 @@ class XGBoostClassifier(BaseClassifierModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -26,7 +26,7 @@ class XGBoostRFClassifier(BaseClassifierModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
@@ -64,7 +64,7 @@ class XGBoostRFClassifier(BaseClassifierModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from xgboost import XGBRFRegressor
|
from xgboost import XGBRFRegressor
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class XGBoostRFRegressor(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from xgboost import XGBRegressor
|
from xgboost import XGBRegressor
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class XGBoostRegressor(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from xgboost import XGBRegressor
|
from xgboost import XGBRegressor
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel):
|
|||||||
top level config.json file.
|
top level config.json file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
User sets up the training and test data to fit their desired model here
|
User sets up the training and test data to fit their desired model here
|
||||||
:param data_dictionary: the dictionary holding all data for train, test,
|
:param data_dictionary: the dictionary holding all data for train, test,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Type, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.logger import HParam
|
from stable_baselines3.common.logger import HParam
|
||||||
@@ -13,10 +13,10 @@ class TensorboardCallback(BaseCallback):
|
|||||||
episodic summary reports.
|
episodic summary reports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, verbose=1, actions: Type[Enum] = BaseActions):
|
def __init__(self, verbose=1, actions: type[Enum] = BaseActions):
|
||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.model: Any = None
|
self.model: Any = None
|
||||||
self.actions: Type[Enum] = actions
|
self.actions: type[Enum] = actions
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
hparam_dict = {
|
hparam_dict = {
|
||||||
@@ -27,7 +27,7 @@ class TensorboardCallback(BaseCallback):
|
|||||||
# "batch_size": self.model.batch_size,
|
# "batch_size": self.model.batch_size,
|
||||||
# "n_steps": self.model.n_steps,
|
# "n_steps": self.model.n_steps,
|
||||||
}
|
}
|
||||||
metric_dict: Dict[str, Union[float, int]] = {
|
metric_dict: dict[str, Union[float, int]] = {
|
||||||
"eval/mean_reward": 0,
|
"eval/mean_reward": 0,
|
||||||
"rollout/ep_rew_mean": 0,
|
"rollout/ep_rew_mean": 0,
|
||||||
"rollout/ep_len_mean": 0,
|
"rollout/ep_len_mean": 0,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -25,7 +25,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
data_convertor: PyTorchDataConvertor,
|
data_convertor: PyTorchDataConvertor,
|
||||||
model_meta_data: Dict[str, Any] = {},
|
model_meta_data: dict[str, Any] = {},
|
||||||
window_size: int = 1,
|
window_size: int = 1,
|
||||||
tb_logger: Any = None,
|
tb_logger: Any = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -61,7 +61,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
self.tb_logger = tb_logger
|
self.tb_logger = tb_logger
|
||||||
self.test_batch_counter = 0
|
self.test_batch_counter = 0
|
||||||
|
|
||||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]):
|
def fit(self, data_dictionary: dict[str, pd.DataFrame], splits: list[str]):
|
||||||
"""
|
"""
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
all the training and test data/labels.
|
all the training and test data/labels.
|
||||||
@@ -102,7 +102,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def estimate_loss(
|
def estimate_loss(
|
||||||
self,
|
self,
|
||||||
data_loader_dictionary: Dict[str, DataLoader],
|
data_loader_dictionary: dict[str, DataLoader],
|
||||||
split: str,
|
split: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
@@ -119,8 +119,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
self.model.train()
|
self.model.train()
|
||||||
|
|
||||||
def create_data_loaders_dictionary(
|
def create_data_loaders_dictionary(
|
||||||
self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]
|
self, data_dictionary: dict[str, pd.DataFrame], splits: list[str]
|
||||||
) -> Dict[str, DataLoader]:
|
) -> dict[str, DataLoader]:
|
||||||
"""
|
"""
|
||||||
Converts the input data to PyTorch tensors using a data loader.
|
Converts the input data to PyTorch tensors using a data loader.
|
||||||
"""
|
"""
|
||||||
@@ -181,7 +181,7 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|||||||
checkpoint = torch.load(path)
|
checkpoint = torch.load(path)
|
||||||
return self.load_from_checkpoint(checkpoint)
|
return self.load_from_checkpoint(checkpoint)
|
||||||
|
|
||||||
def load_from_checkpoint(self, checkpoint: Dict):
|
def load_from_checkpoint(self, checkpoint: dict):
|
||||||
"""
|
"""
|
||||||
when using continual_learning, DataDrawer will load the dictionary
|
when using continual_learning, DataDrawer will load the dictionary
|
||||||
(containing state dicts and model_meta_data) by calling torch.load(path).
|
(containing state dicts and model_meta_data) by calling torch.load(path).
|
||||||
@@ -200,8 +200,8 @@ class PyTorchTransformerTrainer(PyTorchModelTrainer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def create_data_loaders_dictionary(
|
def create_data_loaders_dictionary(
|
||||||
self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]
|
self, data_dictionary: dict[str, pd.DataFrame], splits: list[str]
|
||||||
) -> Dict[str, DataLoader]:
|
) -> dict[str, DataLoader]:
|
||||||
"""
|
"""
|
||||||
Converts the input data to PyTorch tensors using a data loader.
|
Converts the input data to PyTorch tensors using a data loader.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -9,7 +8,7 @@ from torch import nn
|
|||||||
|
|
||||||
class PyTorchTrainerInterface(ABC):
|
class PyTorchTrainerInterface(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data_dictionary: Dict[str, pd.DataFrame], splits: List[str]) -> None:
|
def fit(self, data_dictionary: dict[str, pd.DataFrame], splits: list[str]) -> None:
|
||||||
"""
|
"""
|
||||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
all the training and test data/labels.
|
all the training and test data/labels.
|
||||||
@@ -41,7 +40,7 @@ class PyTorchTrainerInterface(ABC):
|
|||||||
return self.load_from_checkpoint(checkpoint)
|
return self.load_from_checkpoint(checkpoint)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_from_checkpoint(self, checkpoint: Dict) -> nn.Module:
|
def load_from_checkpoint(self, checkpoint: dict) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
when using continual_learning, DataDrawer will load the dictionary
|
when using continual_learning, DataDrawer will load the dictionary
|
||||||
(containing state dicts and model_meta_data) by calling torch.load(path).
|
(containing state dicts and model_meta_data) by calling torch.load(path).
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -155,7 +155,7 @@ def plot_feature_importance(
|
|||||||
store_plot_file(fig, f"{dk.model_filename}-{label}.html", dk.data_path)
|
store_plot_file(fig, f"{dk.model_filename}-{label}.html", dk.data_path)
|
||||||
|
|
||||||
|
|
||||||
def record_params(config: Dict[str, Any], full_path: Path) -> None:
|
def record_params(config: dict[str, Any], full_path: Path) -> None:
|
||||||
"""
|
"""
|
||||||
Records run params in the full path for reproducibility
|
Records run params in the full path for reproducibility
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user