Merge pull request #9468 from freqtrade/freqai/fix_dump

Use cloudpickle to pickle freqai models
This commit is contained in:
Robert Caulk
2023-12-02 18:51:31 +01:00
committed by GitHub

View File

@@ -12,7 +12,6 @@ import numpy as np
import pandas as pd
import psutil
import rapidjson
from joblib import dump, load
from joblib.externals import cloudpickle
from numpy.typing import NDArray
from pandas import DataFrame
@@ -471,7 +470,8 @@ class FreqaiDataDrawer:
# Save the trained model
if self.model_type == 'joblib':
dump(model, save_path / f"{dk.model_filename}_model.joblib")
with (save_path / f"{dk.model_filename}_model.joblib").open("wb") as fp:
cloudpickle.dump(model, fp)
elif self.model_type == 'keras':
model.save(save_path / f"{dk.model_filename}_model.h5")
elif self.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]:
@@ -558,7 +558,8 @@ class FreqaiDataDrawer:
if dk.live and coin in self.model_dictionary:
model = self.model_dictionary[coin]
elif self.model_type == 'joblib':
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
with (dk.data_path / f"{dk.model_filename}_model.joblib").open("rb") as fp:
model = cloudpickle.load(fp)
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
mod = importlib.import_module(
self.model_type, self.freqai_info['rl_config']['model_type'])