Merge pull request #9468 from freqtrade/freqai/fix_dump

Use cloudpickle to pickle freqai models
This commit is contained in:
Robert Caulk
2023-12-02 18:51:31 +01:00
committed by GitHub

View File

@@ -12,7 +12,6 @@ import numpy as np
import pandas as pd import pandas as pd
import psutil import psutil
import rapidjson import rapidjson
from joblib import dump, load
from joblib.externals import cloudpickle from joblib.externals import cloudpickle
from numpy.typing import NDArray from numpy.typing import NDArray
from pandas import DataFrame from pandas import DataFrame
@@ -471,7 +470,8 @@ class FreqaiDataDrawer:
# Save the trained model # Save the trained model
if self.model_type == 'joblib': if self.model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib") with (save_path / f"{dk.model_filename}_model.joblib").open("wb") as fp:
cloudpickle.dump(model, fp)
elif self.model_type == 'keras': elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5") model.save(save_path / f"{dk.model_filename}_model.h5")
elif self.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]: elif self.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]:
@@ -558,7 +558,8 @@ class FreqaiDataDrawer:
if dk.live and coin in self.model_dictionary: if dk.live and coin in self.model_dictionary:
model = self.model_dictionary[coin] model = self.model_dictionary[coin]
elif self.model_type == 'joblib': elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib") with (dk.data_path / f"{dk.model_filename}_model.joblib").open("rb") as fp:
model = cloudpickle.load(fp)
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type: elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
mod = importlib.import_module( mod = importlib.import_module(
self.model_type, self.freqai_info['rl_config']['model_type']) self.model_type, self.freqai_info['rl_config']['model_type'])