make sure its on cuda

This commit is contained in:
Yu Li
2023-12-01 17:50:34 -06:00
parent 750a6908bd
commit 290cfa1122

View File

@@ -209,9 +209,10 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
# Free memory
for k in layer_state_dict.keys():
del state_dict[k]
if k in state_dict:
del state_dict[k]
del layer_state_dict
gc.collect()
clean_memory()
return str(saving_path)