From 5c9fcfe9ea2e320b0019c008aaad92c89a54936a Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 17:31:40 -0600 Subject: [PATCH] change check space according to compression setting --- air_llm/airllm/airllm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index ad50ae7..c1d6dc5 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -109,11 +109,16 @@ def load_layer(local_path, layer_name): -def check_space(checkpoint_path, layer_shards_saving_path=None): +def check_space(checkpoint_path, layer_shards_saving_path=None, compression=None): total_shard_files_size_bytes = 0 for model_shard_file in glob(str(checkpoint_path / '*')): total_shard_files_size_bytes += os.path.getsize(model_shard_file) + if compression == '4bit': + total_shard_files_size_bytes = total_shard_files_size_bytes // 4 + elif compression == '8bit': + total_shard_files_size_bytes = total_shard_files_size_bytes // 2 + total, used, free = shutil.disk_usage(checkpoint_path if layer_shards_saving_path is None else layer_shards_saving_path) if free < total_shard_files_size_bytes: @@ -165,7 +170,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt - check_space(checkpoint_path, layer_shards_saving_path) + check_space(checkpoint_path, layer_shards_saving_path, compression) with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f: index = json.load(f)['weight_map']