Mirror of https://github.com/0xSojalSec/airllm.git (last synced 2026-03-07 22:33:47 +00:00)
support safetensors model
This commit is contained in:
@@ -188,8 +188,15 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
|
||||
|
||||
check_space(checkpoint_path, layer_shards_saving_path, compression)
|
||||
|
||||
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
|
||||
index = json.load(f)['weight_map']
|
||||
safetensors_format = False
|
||||
if os.path.exists(checkpoint_path / 'pytorch_model.bin.index.json'):
|
||||
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
|
||||
index = json.load(f)['weight_map']
|
||||
else:
|
||||
safetensors_format = True
|
||||
assert os.path.exists(checkpoint_path / 'model.safetensors.index.json'), f'model.safetensors.index.json should exist.'
|
||||
with open(checkpoint_path / 'model.safetensors.index.json', 'rb') as f:
|
||||
index = json.load(f)['weight_map']
|
||||
|
||||
n_layers = len(set([int(k.split('.')[2]) for k in index.keys() if 'model.layers' in k]))
|
||||
layers = ['model.embed_tokens.'] + [f'model.layers.{i}.' for i in range(n_layers)] + ['model.norm.', 'lm_head.']
|
||||
@@ -209,8 +216,12 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
|
||||
if max(shards) > shard:
|
||||
shard += 1
|
||||
print(f'Loading shard {shard}/{n_shards}')
|
||||
state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin',
|
||||
map_location='cpu'))
|
||||
if not safetensors_format:
|
||||
state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin',
|
||||
map_location='cpu'))
|
||||
else:
|
||||
state_dict.update(load_file(checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors',
|
||||
map_location='cpu'))
|
||||
|
||||
# Get layer state dict
|
||||
layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)])
|
||||
@@ -255,7 +266,8 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
|
||||
|
||||
# try the local model path first; if the model exists there, split and save it in place
|
||||
if os.path.exists(model_local_path_or_repo_id):
|
||||
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json'):
|
||||
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json') or \
|
||||
os.path.exists(Path(model_local_path_or_repo_id) / 'model.safetensors.index.json'):
|
||||
return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path, compression=compression)
|
||||
else:
|
||||
print(
|
||||
@@ -263,8 +275,9 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
|
||||
|
||||
# it should be a repo id at this point...
|
||||
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id)
|
||||
assert os.path.exists(Path(
|
||||
hf_cache_path) / 'pytorch_model.bin.index.json'), f"{hf_cache_path}/pytorch_model.bin.index.json should exists."
|
||||
assert os.path.exists(Path(hf_cache_path) / 'pytorch_model.bin.index.json') or \
|
||||
os.path.exists(Path(hf_cache_path) / 'model.safetensors.index.json'), \
|
||||
f"{hf_cache_path}/pytorch_model.bin.index.json or {hf_cache_path}/model.safetensors.index.json should exists."
|
||||
|
||||
# if splitted_model subdir exists under cache use it, otherwise split and save
|
||||
return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression)
|
||||
|
||||
Reference in New Issue
Block a user