mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
change check space according to compression setting
This commit is contained in:
@@ -109,11 +109,16 @@ def load_layer(local_path, layer_name):
|
||||
|
||||
|
||||
|
||||
def check_space(checkpoint_path, layer_shards_saving_path=None):
|
||||
def check_space(checkpoint_path, layer_shards_saving_path=None, compression=None):
|
||||
total_shard_files_size_bytes = 0
|
||||
for model_shard_file in glob(str(checkpoint_path / '*')):
|
||||
total_shard_files_size_bytes += os.path.getsize(model_shard_file)
|
||||
|
||||
if compression == '4bit':
|
||||
total_shard_files_size_bytes = total_shard_files_size_bytes // 4
|
||||
elif compression == '8bit':
|
||||
total_shard_files_size_bytes = total_shard_files_size_bytes // 2
|
||||
|
||||
total, used, free = shutil.disk_usage(checkpoint_path if layer_shards_saving_path is None else layer_shards_saving_path)
|
||||
|
||||
if free < total_shard_files_size_bytes:
|
||||
@@ -165,7 +170,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
|
||||
|
||||
|
||||
|
||||
check_space(checkpoint_path, layer_shards_saving_path)
|
||||
check_space(checkpoint_path, layer_shards_saving_path, compression)
|
||||
|
||||
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
|
||||
index = json.load(f)['weight_map']
|
||||
|
||||
Reference in New Issue
Block a user