From 301f372fa0bcc9ae631c6036f7e53d97a3635bec Mon Sep 17 00:00:00 2001 From: Navodplayer1 Date: Sun, 18 Aug 2024 20:40:59 +0530 Subject: [PATCH] added delete_original support for single modelfiles --- air_llm/airllm/utils.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/air_llm/airllm/utils.py b/air_llm/airllm/utils.py index 5fe1890..e2536e9 100644 --- a/air_llm/airllm/utils.py +++ b/air_llm/airllm/utils.py @@ -247,22 +247,21 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt else: print(f"some layer splits found, some are not, re-save all layers in case there's some corruptions.") - if not delete_original: check_space(checkpoint_path, layer_shards_saving_path, compression, splitted_model_dir_name=splitted_model_dir_name) - shard = 0 n_shards = len(set(index.values())) state_dict = {} - if not os.path.exists(saving_path): #os.makedirs(saving_path) saving_path.mkdir(parents=True, exist_ok=True) + single_modelfile = None + for layer in tqdm(layers): # Optionnally load next shard @@ -300,8 +299,8 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt else: shards = [v for k, v in index.items() if k.startswith(layer)] - modelfile = shards[0] - to_load = checkpoint_path / modelfile + single_modelfile = shards[0] + to_load = checkpoint_path / single_modelfile # check if to_load exist, if not downloaad it... 
if not os.path.exists(to_load): assert repo_id is not None @@ -317,14 +316,12 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt layer_state_dict = compress_layer_state_dict(layer_state_dict, compression) - # Save layer state dict as using safetensors marker_exists = ModelPersister.get_model_persister().model_persist_exist(layer, saving_path) if not marker_exists: ModelPersister.get_model_persister().persist_model(layer_state_dict, layer, saving_path) - # Free memory for k in layer_state_dict.keys(): if k in state_dict: @@ -332,6 +329,13 @@ del layer_state_dict clean_memory() + # Delete the single modelfile only after all layers have been persisted, + # when the HF repo contained just one modelfile and delete_original=True. + if delete_original and single_modelfile is not None: + to_delete = checkpoint_path / single_modelfile + print(f"deleting original file: {to_delete}") + remove_real_and_linked_file(to_delete) + return str(saving_path) def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None, compression=None,