Merge pull request #173 from NavodPeiris/main

added delete_original support for single modelfiles
This commit is contained in:
Gavin Li
2024-08-18 11:43:00 -05:00
committed by GitHub

View File

@@ -247,22 +247,21 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
else:
print(f"some layer splits found, some are not, re-save all layers in case there's some corruptions.")
if not delete_original:
check_space(checkpoint_path, layer_shards_saving_path, compression, splitted_model_dir_name=splitted_model_dir_name)
shard = 0
n_shards = len(set(index.values()))
state_dict = {}
if not os.path.exists(saving_path):
#os.makedirs(saving_path)
saving_path.mkdir(parents=True, exist_ok=True)
single_modelfile = None
for layer in tqdm(layers):
# Optionally load next shard
@@ -300,8 +299,8 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
else:
shards = [v for k, v in index.items() if k.startswith(layer)]
modelfile = shards[0]
to_load = checkpoint_path / modelfile
single_modelfile = shards[0]
to_load = checkpoint_path / single_modelfile
# check if to_load exists; if not, download it...
if not os.path.exists(to_load):
assert repo_id is not None
@@ -317,14 +316,12 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
layer_state_dict = compress_layer_state_dict(layer_state_dict, compression)
# Save layer state dict as using safetensors
marker_exists = ModelPersister.get_model_persister().model_persist_exist(layer, saving_path)
if not marker_exists:
ModelPersister.get_model_persister().persist_model(layer_state_dict, layer, saving_path)
# Free memory
for k in layer_state_dict.keys():
if k in state_dict:
@@ -332,6 +329,13 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
del layer_state_dict
clean_memory()
# Delete the single model file if the HF repo contained only one model file.
# When delete_original=True, this deletion must happen at the very end.
if delete_original and single_modelfile != None:
to_delete = checkpoint_path / single_modelfile
print(f"deleting original file: {to_delete}")
remove_real_and_linked_file(to_delete)
return str(saving_path)
def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None, compression=None,