diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 5978e94..72fcbca 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -188,8 +188,15 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt check_space(checkpoint_path, layer_shards_saving_path, compression) - with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f: - index = json.load(f)['weight_map'] + safetensors_format = False + if os.path.exists(checkpoint_path / 'pytorch_model.bin.index.json'): + with open(checkpoint_path / 'pytorch_model.bin.index.json', 'rb') as f: + index = json.load(f)['weight_map'] + else: + safetensors_format = True + assert os.path.exists(checkpoint_path / 'model.safetensors.index.json'), f'model.safetensors.index.json should exist.' + with open(checkpoint_path / 'model.safetensors.index.json', 'rb') as f: + index = json.load(f)['weight_map'] n_layers = len(set([int(k.split('.')[2]) for k in index.keys() if 'model.layers' in k])) layers = ['model.embed_tokens.'] + [f'model.layers.{i}.' for i in range(n_layers)] + ['model.norm.', 'lm_head.'] @@ -209,8 +216,12 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt if max(shards) > shard: shard += 1 print(f'Loading shard {shard}/{n_shards}') - state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin', - map_location='cpu')) + if not safetensors_format: + state_dict.update(torch.load(checkpoint_path / f'pytorch_model-000{shard:02d}-of-000{n_shards:02d}.bin', + map_location='cpu')) + else: + state_dict.update(load_file(checkpoint_path / f'model-000{shard:02d}-of-000{n_shards:02d}.safetensors', + map_location='cpu')) # Get layer state dict layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)]) @@ -255,7 +266,8 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards # try local model path, if the model exist split and save there if os.path.exists(model_local_path_or_repo_id): - if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json'): + if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json') or \ + os.path.exists(Path(model_local_path_or_repo_id) / 'model.safetensors.index.json'): return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path, compression=compression) else: print( @@ -263,8 +275,9 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards # it should be a repo id at this point... hf_cache_path = huggingface_hub.snapshot_download(model_local_path_or_repo_id) - assert os.path.exists(Path( - hf_cache_path) / 'pytorch_model.bin.index.json'), f"{hf_cache_path}/pytorch_model.bin.index.json should exists." + assert os.path.exists(Path(hf_cache_path) / 'pytorch_model.bin.index.json') or \ + os.path.exists(Path(hf_cache_path) / 'model.safetensors.index.json'), \ + f"{hf_cache_path}/pytorch_model.bin.index.json or {hf_cache_path}/model.safetensors.index.json should exists." # if splitted_model subdir exists under cache use it, otherwise split and save return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression)