try limiting tb_logger to pytorch only (XGBoost still gets its callback)

This commit is contained in:
robcaulk
2023-05-14 12:03:15 +00:00
parent 8a9b2fc16f
commit ab7a474ab6

View File

@@ -635,9 +635,11 @@ class IFreqaiModel(ABC):
dk.find_features(unfiltered_dataframe)
dk.find_labels(unfiltered_dataframe)
self.tb_logger = TBLogger(dk.data_path)
if self.dd.model_type == "pytorch":
self.tb_logger = TBLogger(dk.data_path)
model = self.train(unfiltered_dataframe, pair, dk)
self.tb_logger.close()
if self.dd.model_type == "pytorch":
self.tb_logger.close()
self.dd.pair_dict[pair]["trained_timestamp"] = trained_timestamp
dk.set_new_model_names(pair, trained_timestamp)