support delete_original model files

This commit is contained in:
Yu Li
2023-12-27 08:55:23 -06:00
parent 87b139fbc8
commit c8a87f9ad9
5 changed files with 66 additions and 21 deletions

View File

@@ -156,6 +156,7 @@ When initialize the model, we support the following configurations:
* **layer_shards_saving_path**: optionally another path to save the split model
* **hf_token**: huggingface token can be provided here if downloading gated models like: *meta-llama/Llama-2-7b-hf*
* **prefetching**: prefetching to overlap the model loading and compute. By default turned on. For now only AirLLMLlama2 supports this.
* **delete_original**: if you are short on disk space, set delete_original to true to delete the originally downloaded Hugging Face model and keep only the transformed one, saving roughly half of the disk space.
## MacOS

View File

@@ -54,7 +54,7 @@ class AirLLMBaseModel(GenerationMixin):
def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512,
layer_shards_saving_path=None, profiling_mode=False, compression=None,
hf_token=None, prefetching=True):
hf_token=None, prefetching=True, delete_original=False):
"""
Sharded version of LlamaForCausalLM: the model is split into layer shards to reduce GPU memory usage.
During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
@@ -105,7 +105,8 @@ class AirLLMBaseModel(GenerationMixin):
layer_shards_saving_path,
compression=compression,
layer_names=self.layer_names_dict,
hf_token=hf_token)
hf_token=hf_token,
delete_original=delete_original)
self.running_device = device
self.device = torch.device(self.running_device)
self.running_dtype = dtype

View File

@@ -202,17 +202,22 @@ class AirLLMLlamaMlx:
else:
self.least_available = min(available, self.least_available)
print(f"[{msg}] - available mem: {available:.02}mb, least available:{available:.02}mb")
consumed = self.initial_available - available
max_consumed = self.initial_available - self.least_available
print(f"[{msg}] - available mem: {available:.02f}mb, consumed: {consumed:.02f}mb, least available:{available:.02f}mb, max consumed: {max_consumed:.02f}mb")
def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=None, max_seq_len=512,
layer_shards_saving_path=None, profiling_mode=False, compression=None,
hf_token=None, prefetching=True, test_nonlayered=False, show_memory_util=False):
hf_token=None, prefetching=True, test_nonlayered=False, show_memory_util=False,
delete_original=False):
self.hf_token = hf_token
self.set_layer_names_dict()
self.test_nonlayered = test_nonlayered
self.show_memory_util = show_memory_util
self.least_available = None
self.initial_available = psutil.virtual_memory().available / 1024 / 1024
@@ -220,7 +225,8 @@ class AirLLMLlamaMlx:
layer_shards_saving_path,
compression=compression,
layer_names=self.layer_names_dict,
hf_token=hf_token)
hf_token=hf_token,
delete_original=delete_original)
if hf_token is not None:
self.config = AutoConfig.from_pretrained(self.model_local_path, token=hf_token, trust_remote_code=True)
else:

View File

@@ -175,9 +175,18 @@ def compress_layer_state_dict(layer_state_dict, compression=None):
return compressed_layer_state_dict if compressed_layer_state_dict is not None else layer_state_dict
def remove_real_and_linked_file(to_delete):
    """Delete a checkpoint file, and if it is a symlink, its target too.

    Hugging Face cache files are commonly symlinks into a blob store, so
    removing the link alone would leave the large blob on disk; this removes
    both so ``delete_original`` actually reclaims the space.

    Args:
        to_delete: path of the file to remove (``str`` or ``os.PathLike``).

    Note: the original implementation compared ``os.path.realpath(to_delete)``
    (a ``str``) against ``to_delete`` directly; callers pass ``pathlib.Path``
    objects, so the comparison was always unequal and a regular (non-symlink)
    file was removed twice, raising ``FileNotFoundError``. Normalizing with
    ``os.fspath`` fixes both that crash and the symlink detection.
    """
    to_delete = os.fspath(to_delete)
    target = os.path.realpath(to_delete)
    os.remove(to_delete)
    # Only a genuine symlink resolves to a different real path; remove the
    # underlying target as well (guard against dangling links).
    if target != to_delete and os.path.exists(target):
        os.remove(target)
def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitted_model_dir_name='splitted_model',
compression=None, layer_names=None):
compression=None, layer_names=None, delete_original=False, repo_id=None, hf_token=None):
"""
Save all layers of a model's sharded checkpoint using safetensors.
"""
@@ -230,7 +239,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
for layer in layers:
found_layers[layer] = ModelPersister.get_model_persister().model_persist_exist(layer, saving_path)
print(f"found_layers:{found_layers}")
if all(found_layers.values()):
# already downloaded, return saving path...
print(f"saved layers already found in {saving_path}")
@@ -239,8 +248,8 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
print(f"some layer splits found, some are not, re-save all layers in case there's some corruptions.")
check_space(checkpoint_path, layer_shards_saving_path, compression, splitted_model_dir_name=splitted_model_dir_name)
if not delete_original:
check_space(checkpoint_path, layer_shards_saving_path, compression, splitted_model_dir_name=splitted_model_dir_name)
@@ -259,14 +268,34 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
# Optionally load the next shard
shards = [int(v.split('-')[1]) for k, v in index.items() if k.startswith(layer)]
if max(shards) > shard:
# optionally delete the original file
if delete_original and shard != 0:
if not safetensors_format:
to_delete = checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin'
else:
to_delete = checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors'
print(f"deleting original file: {to_delete}")
remove_real_and_linked_file(to_delete)
shard += 1
print(f'Loading shard {shard}/{n_shards}')
if not safetensors_format:
state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin',
map_location='cpu'))
to_load = checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin'
else:
state_dict.update(load_file(checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors',
device='cpu'))
to_load = checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors'
# check if to_load exists; if not, download it...
if not os.path.exists(to_load):
assert repo_id is not None
huggingface_hub.snapshot_download(repo_id, allow_patterns=os.path.basename(to_load),
token=hf_token)
if not safetensors_format:
state_dict.update(torch.load(to_load, map_location='cpu'))
else:
state_dict.update(load_file(to_load, device='cpu'))
# Get layer state dict
layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)])
@@ -291,7 +320,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
return str(saving_path)
def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None, compression=None,
layer_names=None, hf_token=None):
layer_names=None, hf_token=None, delete_original=False):
"""
find the model's local cache path, download the cache if not exists, then split and save the model.
@@ -318,17 +347,23 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
if os.path.exists(model_local_path_or_repo_id):
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json') or \
os.path.exists(Path(model_local_path_or_repo_id) / 'model.safetensors.index.json'):
print(f"found index file...")
return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path,
compression=compression, layer_names=layer_names)
compression=compression, layer_names=layer_names, delete_original=delete_original)
else:
print(
f"Found local directory in {model_local_path_or_repo_id}, but didn't find downloaded model. Try using {model_local_path_or_repo_id} as a HF repo...")
# it should be a repo id at this point...
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id, token=hf_token,
#allow_patterns= ["model.safetensors.index.json", 'pytorch_model.bin.index.json'],
ignore_patterns=['*.safetensors', '*.bin'])
# check if there's safetensors saved, if so, exclude torch saves
# delay download now...
'''
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id, token=hf_token, allow_patterns="model.safetensors.index.json")
if len(glob(str(Path(hf_cache_path) / "model.safetensors.index.json"))) > 0:
# there's safe tensor version, exclude torch version
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id, token=hf_token,
@@ -337,11 +372,13 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
else:
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id,
token=hf_token)
'''
assert os.path.exists(Path(hf_cache_path) / 'pytorch_model.bin.index.json') or \
os.path.exists(Path(hf_cache_path) / 'model.safetensors.index.json'), \
f"{hf_cache_path}/pytorch_model.bin.index.json or {hf_cache_path}/model.safetensors.index.json should exists."
#assert os.path.exists(Path(hf_cache_path) / 'pytorch_model.bin.index.json') or \
# os.path.exists(Path(hf_cache_path) / 'model.safetensors.index.json'), \
# f"{hf_cache_path}/pytorch_model.bin.index.json or {hf_cache_path}/model.safetensors.index.json should exists."
# if splitted_model subdir exists under cache use it, otherwise split and save
return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path,
compression=compression, layer_names=layer_names)
compression=compression, layer_names=layer_names,
delete_original=delete_original, repo_id=model_local_path_or_repo_id, hf_token=hf_token)

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="airllm",
version="2.8.2",
version="2.8.3",
author="Gavin Li",
author_email="gavinli@animaai.cloud",
description="AirLLM allows single 4GB GPU card to run 70B large language models without quantization, distillation or pruning.",