mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
profiling and compression
This commit is contained in:
@@ -7,6 +7,7 @@ import shutil
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from glob import glob
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig
|
||||
@@ -17,6 +18,51 @@ from safetensors.torch import load_file, save_file
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
import huggingface_hub
|
||||
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
bitsandbytes_installed = True
|
||||
print('>>>> bitsandbytes installed')
|
||||
except ImportError:
|
||||
bitsandbytes_installed = False
|
||||
|
||||
total_disk_loading_time = None
|
||||
total_gpu_loading_time = None
|
||||
|
||||
|
||||
# replacement for bnb quantstat.as_dict(True), until the bug is fixed....
|
||||
def save_quant_state_to_dict(self, packed=True):
|
||||
"""
|
||||
returns dict of tensors and strings to use in serialization via _save_to_state_dict()
|
||||
param: packed -- returns dict[str, torch.Tensor] for state_dict
|
||||
"""
|
||||
qs_dict = {
|
||||
'quant_type': self.quant_type,
|
||||
'absmax': self.absmax,
|
||||
'blocksize': self.blocksize,
|
||||
'quant_map': self.code,
|
||||
'dtype': str(self.dtype).strip('torch.'),
|
||||
'shape': tuple(self.shape),
|
||||
}
|
||||
if self.nested:
|
||||
qs_dict.update({
|
||||
'nested_absmax': self.state2.absmax,
|
||||
'nested_blocksize': self.state2.blocksize,
|
||||
'nested_quant_map': self.state2.code,
|
||||
'nested_dtype': str(self.state2.dtype).strip('torch.'),
|
||||
'nested_offset': self.offset.item(),
|
||||
})
|
||||
if not packed:
|
||||
return qs_dict
|
||||
|
||||
qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
|
||||
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
|
||||
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = bnb.utils.pack_dict_to_tensor(non_tensor_dict)
|
||||
return qs_packed_dict
|
||||
|
||||
|
||||
|
||||
class NotEnoughSpaceException(Exception):
|
||||
pass
|
||||
|
||||
@@ -30,9 +76,37 @@ def clean_memory():
|
||||
pass
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def uncompress_layer_state_dict(layer_state_dict):
|
||||
uncompressed_layer_state_dict = None
|
||||
if any(['4bit' in k for k in layer_state_dict.keys()]):
|
||||
uncompressed_layer_state_dict = {}
|
||||
for k, v in layer_state_dict.items():
|
||||
if '4bit' not in k:
|
||||
quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk}
|
||||
quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cpu")
|
||||
|
||||
dqv = bnb.functional.dequantize_nf4(v, quant_state)
|
||||
uncompressed_layer_state_dict[k] = dqv
|
||||
del layer_state_dict
|
||||
elif any(['8bit' in k for k in layer_state_dict.keys()]):
|
||||
uncompressed_layer_state_dict = {}
|
||||
for k, v in layer_state_dict.items():
|
||||
if '8bit' not in k:
|
||||
|
||||
absmax = layer_state_dict[k + ".8bit.absmax"]
|
||||
code = layer_state_dict[k + ".8bit.code"]
|
||||
|
||||
dqv = bnb.functional.dequantize(v, (absmax, code))
|
||||
uncompressed_layer_state_dict[k] = dqv
|
||||
del layer_state_dict
|
||||
|
||||
return layer_state_dict if uncompressed_layer_state_dict is None else uncompressed_layer_state_dict
|
||||
|
||||
def load_layer(local_path, layer_name):
|
||||
layer_state_dict = load_file(Path(local_path) / (layer_name + ".safetensors"), device="cpu")
|
||||
return layer_state_dict
|
||||
return uncompress_layer_state_dict(layer_state_dict)
|
||||
|
||||
|
||||
|
||||
def check_space(checkpoint_path, layer_shards_saving_path=None):
|
||||
@@ -46,12 +120,35 @@ def check_space(checkpoint_path, layer_shards_saving_path=None):
|
||||
raise NotEnoughSpaceException(f"Not enough space. Free space under {checkpoint_path if layer_shards_saving_path is None else layer_shards_saving_path}:" \
|
||||
f" {free / 1024 / 1024 / 1024:.02f}GB. Model total size: {total_shard_files_size_bytes / 1024 / 1024 / 1024:.02f}GB.")
|
||||
|
||||
def compress_layer_state_dict(layer_state_dict, compression=None):
|
||||
compressed_layer_state_dict = None
|
||||
if compression == '4bit':
|
||||
compressed_layer_state_dict = {}
|
||||
for k, v in layer_state_dict.items():
|
||||
v_quant, quant_state = bnb.functional.quantize_nf4(v, blocksize=64)
|
||||
compressed_layer_state_dict[k] = v_quant
|
||||
for quant_state_k, quant_state_v in save_quant_state_to_dict(quant_state).items():
|
||||
compressed_layer_state_dict[k + ".4bit." + quant_state_k] = quant_state_v
|
||||
elif compression == '8bit':
|
||||
compressed_layer_state_dict = {}
|
||||
for k, v in layer_state_dict.items():
|
||||
v_quant, (absmax, code) = bnb.functional.quantize(v)
|
||||
compressed_layer_state_dict[k] = v_quant
|
||||
compressed_layer_state_dict[k + ".8bit.absmax"] = absmax
|
||||
compressed_layer_state_dict[k + ".8bit.code"] = code
|
||||
|
||||
def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitted_model_dir_name='splitted_model'):
|
||||
return compressed_layer_state_dict if compressed_layer_state_dict is not None else layer_state_dict
|
||||
|
||||
|
||||
def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitted_model_dir_name='splitted_model', compression=None):
|
||||
"""
|
||||
Save the all layers of a model sharded checkpoint using safetensors.
|
||||
"""
|
||||
|
||||
if compression is not None:
|
||||
assert bitsandbytes_installed, f"when using compression bitsandbytes has to be installed."
|
||||
splitted_model_dir_name = splitted_model_dir_name + "." + compression
|
||||
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
|
||||
|
||||
@@ -95,6 +192,9 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
|
||||
# Get layer state dict
|
||||
layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)])
|
||||
|
||||
layer_state_dict = compress_layer_state_dict(layer_state_dict, compression)
|
||||
|
||||
|
||||
# Save layer state dict as using safetensors
|
||||
save_file(layer_state_dict, saving_path / (layer + 'safetensors'))
|
||||
|
||||
@@ -108,7 +208,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt
|
||||
|
||||
return str(saving_path)
|
||||
|
||||
def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None):
|
||||
def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None, compression=None):
|
||||
"""
|
||||
find the model's local cache path, download the cache if not exists, then split and save the model.
|
||||
|
||||
@@ -125,12 +225,14 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
|
||||
local model path
|
||||
saved_layer_shards_path : str
|
||||
the path saved layer shards
|
||||
compression: str, optinal
|
||||
setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeed up 4x or 2x inference time with a tiny accuracy loss.
|
||||
"""
|
||||
|
||||
# try local model path, if the model exist split and save there
|
||||
if os.path.exists(model_local_path_or_repo_id):
|
||||
if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json'):
|
||||
return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path)
|
||||
return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path, compression=compression)
|
||||
else:
|
||||
print(
|
||||
f"Found local directory in {model_local_path_or_repo_id}, but didn't find downloaded model. Try using {model_local_path_or_repo_id} as a HF repo...")
|
||||
@@ -141,13 +243,13 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards
|
||||
hf_cache_path) / 'pytorch_model.bin.index.json'), f"{hf_cache_path}/pytorch_model.bin.index.json should exists."
|
||||
|
||||
# if splitted_model subdir exists under cache use it, otherwise split and save
|
||||
return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path)
|
||||
return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression)
|
||||
|
||||
|
||||
|
||||
class AirLLMLlama2(GenerationMixin):
|
||||
def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512,
|
||||
layer_shards_saving_path=None):
|
||||
layer_shards_saving_path=None, profiling_mode=False, compression=None):
|
||||
"""
|
||||
Sharded version of LlamaForCausalLM : the model is splitted into layer shards to reduce GPU memory usage.
|
||||
During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
|
||||
@@ -165,10 +267,27 @@ class AirLLMLlama2(GenerationMixin):
|
||||
max seq lenght, by default 512
|
||||
layer_shards_saving_path : str, optional
|
||||
optional path to save layered shards model file, by default just save to the local cache of model, subdir named splitted_model will be saved
|
||||
profiling_mode : book, optional
|
||||
if to profile the model loading time, default to False
|
||||
compression: str, optinal
|
||||
setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeed up 4x or 2x inference time with a tiny accuracy loss.
|
||||
"""
|
||||
|
||||
|
||||
if profiling_mode:
|
||||
self.profiling_mode = profiling_mode
|
||||
|
||||
if compression is not None:
|
||||
if not bitsandbytes_installed:
|
||||
raise ImportError('WARNING: bitsandbytes not found. Compression needs bitsandbytes. To use compression, please install bitsandbytes: `pip install bitsandbytes`')
|
||||
|
||||
|
||||
self.compression = compression
|
||||
|
||||
# Save parameters
|
||||
self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path)
|
||||
self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id,
|
||||
layer_shards_saving_path,
|
||||
compression=compression)
|
||||
self.running_device = device
|
||||
self.device = torch.device(self.running_device)
|
||||
self.running_dtype = dtype
|
||||
@@ -277,6 +396,12 @@ class AirLLMLlama2(GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
global total_disk_loading_time, total_gpu_loading_time
|
||||
|
||||
if self.profiling_mode:
|
||||
total_disk_loading_time = []
|
||||
total_gpu_loading_time = []
|
||||
|
||||
# Reboot the model to make sure buffers are loaded and memory is clean
|
||||
del self.model
|
||||
clean_memory()
|
||||
@@ -304,8 +429,19 @@ class AirLLMLlama2(GenerationMixin):
|
||||
for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device,
|
||||
total=len(self.layers)):
|
||||
|
||||
t = time.process_time()
|
||||
state_dict = self.load_layer_to_cpu(layer_name)
|
||||
elapsed_time = time.process_time() - t
|
||||
# profile
|
||||
if self.profiling_mode:
|
||||
total_disk_loading_time.append(elapsed_time)
|
||||
|
||||
t = time.process_time()
|
||||
self.move_layer_to_device(state_dict)
|
||||
elapsed_time = time.process_time() - t
|
||||
# profile
|
||||
if self.profiling_mode:
|
||||
total_gpu_loading_time.append(elapsed_time)
|
||||
|
||||
# Run layer
|
||||
|
||||
@@ -403,6 +539,13 @@ class AirLLMLlama2(GenerationMixin):
|
||||
tuple(kv_cache_list) if kv_cache_list is not None else None,
|
||||
tuple(all_hidden_states) if all_hidden_states is not None else None,
|
||||
tuple(all_self_attns) if all_self_attns is not None else None] if v is not None)
|
||||
if self.profiling_mode:
|
||||
print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
|
||||
print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
|
||||
|
||||
total_disk_loading_time = []
|
||||
total_gpu_loading_time = []
|
||||
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=None,
|
||||
|
||||
@@ -21,6 +21,8 @@ setuptools.setup(
|
||||
'safetensors',
|
||||
'optimum',
|
||||
'huggingface_hub'
|
||||
'scipy',
|
||||
#'bitsandbytes' set it to optional to support fallback when not installable
|
||||
],
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
||||
40
air_llm/tests/test_compression.py
Normal file
40
air_llm/tests/test_compression.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
sys.path.insert(0, '../airllm')
|
||||
|
||||
from airllm import compress_layer_state_dict, uncompress_layer_state_dict
|
||||
|
||||
|
||||
|
||||
|
||||
class TestCompression(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_should_compress_uncompress(self):
|
||||
torch.manual_seed(0)
|
||||
a0 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
|
||||
a1 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda()
|
||||
|
||||
a_state_dict = {'a0':a0, 'a1':a1}
|
||||
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
|
||||
for compression in [None, '4bit', '8bit']:
|
||||
b = compress_layer_state_dict(a_state_dict, compression)
|
||||
|
||||
print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }")
|
||||
|
||||
aa = uncompress_layer_state_dict(b)
|
||||
|
||||
for k in aa.keys():
|
||||
|
||||
if compression is None:
|
||||
self.assertAlmostEqual(aa[k], a[k])
|
||||
else:
|
||||
RMSE_loss = torch.sqrt(loss_fn(aa[k], a[k]))
|
||||
self.assertLess(RMSE_loss.detach().numpy()[0], 0.5)
|
||||
Reference in New Issue
Block a user