mirror of
https://github.com/freqtrade/freqtrade.git
synced 2025-11-29 08:33:07 +00:00
Merge pull request #9468 from freqtrade/freqai/fix_dump
Use cloudpickle to pickle freqai models
This commit is contained in:
@@ -12,7 +12,6 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import psutil
|
import psutil
|
||||||
import rapidjson
|
import rapidjson
|
||||||
from joblib import dump, load
|
|
||||||
from joblib.externals import cloudpickle
|
from joblib.externals import cloudpickle
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
@@ -471,7 +470,8 @@ class FreqaiDataDrawer:
|
|||||||
|
|
||||||
# Save the trained model
|
# Save the trained model
|
||||||
if self.model_type == 'joblib':
|
if self.model_type == 'joblib':
|
||||||
dump(model, save_path / f"{dk.model_filename}_model.joblib")
|
with (save_path / f"{dk.model_filename}_model.joblib").open("wb") as fp:
|
||||||
|
cloudpickle.dump(model, fp)
|
||||||
elif self.model_type == 'keras':
|
elif self.model_type == 'keras':
|
||||||
model.save(save_path / f"{dk.model_filename}_model.h5")
|
model.save(save_path / f"{dk.model_filename}_model.h5")
|
||||||
elif self.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]:
|
elif self.model_type in ["stable_baselines3", "sb3_contrib", "pytorch"]:
|
||||||
@@ -558,7 +558,8 @@ class FreqaiDataDrawer:
|
|||||||
if dk.live and coin in self.model_dictionary:
|
if dk.live and coin in self.model_dictionary:
|
||||||
model = self.model_dictionary[coin]
|
model = self.model_dictionary[coin]
|
||||||
elif self.model_type == 'joblib':
|
elif self.model_type == 'joblib':
|
||||||
model = load(dk.data_path / f"{dk.model_filename}_model.joblib")
|
with (dk.data_path / f"{dk.model_filename}_model.joblib").open("rb") as fp:
|
||||||
|
model = cloudpickle.load(fp)
|
||||||
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
|
elif 'stable_baselines' in self.model_type or 'sb3_contrib' == self.model_type:
|
||||||
mod = importlib.import_module(
|
mod = importlib.import_module(
|
||||||
self.model_type, self.freqai_info['rl_config']['model_type'])
|
self.model_type, self.freqai_info['rl_config']['model_type'])
|
||||||
|
|||||||
Reference in New Issue
Block a user