From c2d13c1063cdeec9b32c07154802203d83baec53 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 16:27:49 -0600 Subject: [PATCH] profiling and compression --- air_llm/airllm/airllm.py | 157 ++++++++++++++++++++++++++++-- air_llm/setup.py | 2 + air_llm/tests/test_compression.py | 40 ++++++++ 3 files changed, 192 insertions(+), 7 deletions(-) create mode 100644 air_llm/tests/test_compression.py diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 0e8bf02..94ebf15 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -7,6 +7,7 @@ import shutil from tqdm import tqdm from pathlib import Path from glob import glob +import time import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig @@ -17,6 +18,51 @@ from safetensors.torch import load_file, save_file from optimum.bettertransformer import BetterTransformer import huggingface_hub + +try: + import bitsandbytes as bnb + + bitsandbytes_installed = True + print('>>>> bitsandbytes installed') +except ImportError: + bitsandbytes_installed = False + +total_disk_loading_time = None +total_gpu_loading_time = None + + +# replacement for bnb quantstat.as_dict(True), until the bug is fixed.... +def save_quant_state_to_dict(self, packed=True): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, torch.Tensor] for state_dict + """ + qs_dict = { + 'quant_type': self.quant_type, + 'absmax': self.absmax, + 'blocksize': self.blocksize, + 'quant_map': self.code, + 'dtype': str(self.dtype).strip('torch.'), + 'shape': tuple(self.shape), + } + if self.nested: + qs_dict.update({ + 'nested_absmax': self.state2.absmax, + 'nested_blocksize': self.state2.blocksize, + 'nested_quant_map': self.state2.code, + 'nested_dtype': str(self.state2.dtype).strip('torch.'), + 'nested_offset': self.offset.item(), + }) + if not packed: + return qs_dict + + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = bnb.utils.pack_dict_to_tensor(non_tensor_dict) + return qs_packed_dict + + + class NotEnoughSpaceException(Exception): pass @@ -30,9 +76,37 @@ def clean_memory(): pass torch.cuda.empty_cache() + +def uncompress_layer_state_dict(layer_state_dict): + uncompressed_layer_state_dict = None + if any(['4bit' in k for k in layer_state_dict.keys()]): + uncompressed_layer_state_dict = {} + for k, v in layer_state_dict.items(): + if '4bit' not in k: + quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk} + quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cpu") + + dqv = bnb.functional.dequantize_nf4(v, quant_state) + uncompressed_layer_state_dict[k] = dqv + del layer_state_dict + elif any(['8bit' in k for k in layer_state_dict.keys()]): + uncompressed_layer_state_dict = {} + for k, v in layer_state_dict.items(): + if '8bit' not in k: + + absmax = layer_state_dict[k + ".8bit.absmax"] + code = layer_state_dict[k + ".8bit.code"] + + dqv = bnb.functional.dequantize(v, (absmax, code)) + uncompressed_layer_state_dict[k] = dqv + del layer_state_dict + + return layer_state_dict if uncompressed_layer_state_dict is None else uncompressed_layer_state_dict + def load_layer(local_path, layer_name): layer_state_dict = load_file(Path(local_path) / (layer_name + ".safetensors"), device="cpu") - return layer_state_dict + return uncompress_layer_state_dict(layer_state_dict) + def check_space(checkpoint_path, layer_shards_saving_path=None): @@ -46,12 +120,35 @@ def check_space(checkpoint_path, layer_shards_saving_path=None): raise NotEnoughSpaceException(f"Not enough space. Free space under {checkpoint_path if layer_shards_saving_path is None else layer_shards_saving_path}:" \ f" {free / 1024 / 1024 / 1024:.02f}GB. Model total size: {total_shard_files_size_bytes / 1024 / 1024 / 1024:.02f}GB.") +def compress_layer_state_dict(layer_state_dict, compression=None): + compressed_layer_state_dict = None + if compression == '4bit': + compressed_layer_state_dict = {} + for k, v in layer_state_dict.items(): + v_quant, quant_state = bnb.functional.quantize_nf4(v, blocksize=64) + compressed_layer_state_dict[k] = v_quant + for quant_state_k, quant_state_v in save_quant_state_to_dict(quant_state).items(): + compressed_layer_state_dict[k + ".4bit." + quant_state_k] = quant_state_v + elif compression == '8bit': + compressed_layer_state_dict = {} + for k, v in layer_state_dict.items(): + v_quant, (absmax, code) = bnb.functional.quantize(v) + compressed_layer_state_dict[k] = v_quant + compressed_layer_state_dict[k + ".8bit.absmax"] = absmax + compressed_layer_state_dict[k + ".8bit.code"] = code -def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitted_model_dir_name='splitted_model'): + return compressed_layer_state_dict if compressed_layer_state_dict is not None else layer_state_dict + + +def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitted_model_dir_name='splitted_model', compression=None): """ Save the all layers of a model sharded checkpoint using safetensors. """ + if compression is not None: + assert bitsandbytes_installed, f"when using compression bitsandbytes has to be installed." + splitted_model_dir_name = splitted_model_dir_name + "." + compression + checkpoint_path = Path(checkpoint_path) @@ -95,6 +192,9 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt # Get layer state dict layer_state_dict = dict([(k, v) for k, v in state_dict.items() if k.startswith(layer)]) + layer_state_dict = compress_layer_state_dict(layer_state_dict, compression) + + # Save layer state dict as using safetensors save_file(layer_state_dict, saving_path / (layer + 'safetensors')) @@ -108,7 +208,7 @@ def split_and_save_layers(checkpoint_path, layer_shards_saving_path=None, splitt return str(saving_path) -def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None): +def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path=None, compression=None): """ find the model's local cache path, download the cache if not exists, then split and save the model. @@ -125,12 +225,14 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards local model path saved_layer_shards_path : str the path saved layer shards + compression: str, optinal + setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeed up 4x or 2x inference time with a tiny accuracy loss. """ # try local model path, if the model exist split and save there if os.path.exists(model_local_path_or_repo_id): if os.path.exists(Path(model_local_path_or_repo_id) / 'pytorch_model.bin.index.json'): - return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path) + return Path(model_local_path_or_repo_id), split_and_save_layers(model_local_path_or_repo_id, layer_shards_saving_path, compression=compression) else: print( f"Found local directory in {model_local_path_or_repo_id}, but didn't find downloaded model. Try using {model_local_path_or_repo_id} as a HF repo...") @@ -141,13 +243,13 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards hf_cache_path) / 'pytorch_model.bin.index.json'), f"{hf_cache_path}/pytorch_model.bin.index.json should exists." # if splitted_model subdir exists under cache use it, otherwise split and save - return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path) + return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression) class AirLLMLlama2(GenerationMixin): def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512, - layer_shards_saving_path=None): + layer_shards_saving_path=None, profiling_mode=False, compression=None): """ Sharded version of LlamaForCausalLM : the model is splitted into layer shards to reduce GPU memory usage. During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer. @@ -165,10 +267,27 @@ class AirLLMLlama2(GenerationMixin): max seq lenght, by default 512 layer_shards_saving_path : str, optional optional path to save layered shards model file, by default just save to the local cache of model, subdir named splitted_model will be saved + profiling_mode : book, optional + if to profile the model loading time, default to False + compression: str, optinal + setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeed up 4x or 2x inference time with a tiny accuracy loss. """ + + if profiling_mode: + self.profiling_mode = profiling_mode + + if compression is not None: + if not bitsandbytes_installed: + raise ImportError('WARNING: bitsandbytes not found. Compression needs bitsandbytes. To use compression, please install bitsandbytes: `pip install bitsandbytes`') + + + self.compression = compression + # Save parameters - self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards_saving_path) + self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id, + layer_shards_saving_path, + compression=compression) self.running_device = device self.device = torch.device(self.running_device) self.running_dtype = dtype @@ -277,6 +396,12 @@ class AirLLMLlama2(GenerationMixin): return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: + global total_disk_loading_time, total_gpu_loading_time + + if self.profiling_mode: + total_disk_loading_time = [] + total_gpu_loading_time = [] + # Reboot the model to make sure buffers are loaded and memory is clean del self.model clean_memory() @@ -304,8 +429,19 @@ class AirLLMLlama2(GenerationMixin): for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device, total=len(self.layers)): + t = time.process_time() state_dict = self.load_layer_to_cpu(layer_name) + elapsed_time = time.process_time() - t + # profile + if self.profiling_mode: + total_disk_loading_time.append(elapsed_time) + + t = time.process_time() self.move_layer_to_device(state_dict) + elapsed_time = time.process_time() - t + # profile + if self.profiling_mode: + total_gpu_loading_time.append(elapsed_time) # Run layer @@ -403,6 +539,13 @@ class AirLLMLlama2(GenerationMixin): tuple(kv_cache_list) if kv_cache_list is not None else None, tuple(all_hidden_states) if all_hidden_states is not None else None, tuple(all_self_attns) if all_self_attns is not None else None] if v is not None) + if self.profiling_mode: + print(f"total disk loading time: {sum(total_disk_loading_time):.04f}") + print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}") + + total_disk_loading_time = [] + total_gpu_loading_time = [] + return CausalLMOutputWithPast( loss=None, diff --git a/air_llm/setup.py b/air_llm/setup.py index cd6c7eb..9f39a4d 100644 --- a/air_llm/setup.py +++ b/air_llm/setup.py @@ -21,6 +21,8 @@ setuptools.setup( 'safetensors', 'optimum', 'huggingface_hub' + 'scipy', + #'bitsandbytes' set it to optional to support fallback when not installable ], classifiers=[ "Programming Language :: Python :: 3", diff --git a/air_llm/tests/test_compression.py b/air_llm/tests/test_compression.py new file mode 100644 index 0000000..bb8479a --- /dev/null +++ b/air_llm/tests/test_compression.py @@ -0,0 +1,40 @@ +import sys +import unittest + +import torch +sys.path.insert(0, '../airllm') + +from airllm import compress_layer_state_dict, uncompress_layer_state_dict + + + + +class TestCompression(unittest.TestCase): + def setUp(self): + pass + def tearDown(self): + pass + + def test_should_compress_uncompress(self): + torch.manual_seed(0) + a0 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda() + a1 = torch.normal(0, 1, (32, 128), dtype=torch.float16).cuda() + + a_state_dict = {'a0':a0, 'a1':a1} + + loss_fn = torch.nn.MSELoss() + + for compression in [None, '4bit', '8bit']: + b = compress_layer_state_dict(a_state_dict, compression) + + print(f"for compression {compression}, compressed to: { {k:v.shape for k,v in b.items()} }") + + aa = uncompress_layer_state_dict(b) + + for k in aa.keys(): + + if compression is None: + self.assertAlmostEqual(aa[k], a[k]) + else: + RMSE_loss = torch.sqrt(loss_fn(aa[k], a[k])) + self.assertLess(RMSE_loss.detach().numpy()[0], 0.5) \ No newline at end of file