support safetensors model

This commit is contained in:
Yu Li
2023-12-01 21:53:30 -06:00
parent b0999223d0
commit 9e6557aeee

View File

@@ -188,8 +188,15 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
check_space(checkpoint_path, layer_shards_saving_path, compression)
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
index = json.load(f)['weight_map']
safetensors_format = False
if os.path.exists(checkpoint_path / 'pytorch_model.bin.index.json'):
with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f:
index = json.load(f)['weight_map']
else:
safetensors_format = True
assert os.path.exists(checkpoint_path / 'model.safetensors.index.json'), f'model.safetensors.index.json should exist.'
with open(checkpoint_path / 'model.safetensors.index.json', 'rb') as f:
index = json.load(f)['weight_map']
n_layers = len(set([int(k.split('.')[2]) for k in index.keys() if 'model.layers' in k]))
layers = ['model.embed_tokens.'] + [f'model.layers.{i}.' for i in range(n_layers)] + ['model.norm.', 'lm_head.']
@@ -209,8 +216,12 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
if max(shards) > shard:
shard += 1
print(f'Loading shard {shard}/{n_shards}')
state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin',
map_location='cpu'))
if not safetensors_format:
state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin',
map_location='cpu'))
else:
state_dict.update(load_file(checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors',
map_location='cpu'))
# Get layer state dict
layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)])
@@ -255,7 +266,8 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
# try local model path, if the model exist split and save there
if os.path.exists(model_local_path_or_repo_id):
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json'):
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json') or \
os.path.exists(Path(model_local_path_or_repo_id) / 'model.safetensors.index.json'):
return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path, compression=compression)
else:
print(
@@ -263,8 +275,9 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
# it should be a repo id at this point...
hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id)
assert os.path.exists(Path(
hf_cache_path) / 'pytorch_model.bin.index.json'), f"{hf_cache_path}/pytorch_model.bin.index.json should exists."
assert os.path.exists(Path(hf_cache_path) / 'pytorch_model.bin.index.json') or \
os.path.exists(Path(hf_cache_path) / 'model.safetensors.index.json'), \
f"{hf_cache_path}/pytorch_model.bin.index.json or {hf_cache_path}/model.safetensors.index.json should exists."
# if splitted_model subdir exists under cache use it, otherwise split and save
return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression)