From 4159a535ad5b54ed70c65aa78f2bd9a3e03903fd Mon Sep 17 00:00:00 2001 From: Yu Li Date: Mon, 18 Dec 2023 23:05:26 -0600 Subject: [PATCH] add prefetching --- air_llm/README.md | 3 + air_llm/airllm/airllm.py | 144 +++++++++++++++++++++++++------------ air_llm/airllm/profiler.py | 25 +++++++ air_llm/airllm/utils.py | 8 ++- 4 files changed, 132 insertions(+), 48 deletions(-) create mode 100644 air_llm/airllm/profiler.py diff --git a/air_llm/README.md b/air_llm/README.md index 61a4480..8a27b7c 100644 --- a/air_llm/README.md +++ b/air_llm/README.md @@ -7,6 +7,8 @@ AirLLM优化inference内存,4GB单卡GPU可以运行70B大语言模型推理 ## Updates +[2023/12/18] added prefetching to overlap the model loading and compute. 10% speed improvement. + [2023/12/03] added support of **ChatGLM**, **QWen**, **Baichuan**, **Mistral**, **InternLM**! 支持ChatGLM, QWEN, Baichuan, Mistral, InternLM! @@ -128,6 +130,7 @@ When initialize the model, we support the following configurations: * **profiling_mode**: supported options: True to output time consumptions or by default False * **layer_shards_saving_path**: optionally another path to save the splitted model * **hf_token**: huggingface token can be provided here if downloading gated models like: *meta-llama/Llama-2-7b-hf* +* **prefetching**: prefetching to overlap the model loading and compute. By default turned on. For now only AirLLMLlama2 supports this. ### 5. Supported Models diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 2205042..abc1c61 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -8,15 +8,17 @@ from tqdm import tqdm from pathlib import Path from glob import glob import time +from concurrent.futures import ThreadPoolExecutor import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig from transformers.modeling_outputs import CausalLMOutputWithPast from accelerate import init_empty_weights + from accelerate.utils.modeling import set_module_tensor_to_device -from safetensors.torch import load_file, save_file +from .profiler import LayeredProfiler + from optimum.bettertransformer import BetterTransformer -import huggingface_hub from .utils import save_quant_state_to_dict, NotEnoughSpaceException, clean_memory, uncompress_layer_state_dict, load_layer, \ check_space, compress_layer_state_dict, split_and_save_layers, find_or_create_local_splitted_path @@ -41,16 +43,13 @@ except ImportError: -total_disk_loading_time = None -total_gpu_loading_time = None -total_compression_overhead_time = None class AirLLMLlama2(GenerationMixin): def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512, layer_shards_saving_path=None, profiling_mode=False, compression=None, - hf_token=None): + hf_token=None, prefetching=True): """ Sharded version of LlamaForCausalLM : the model is splitted into layer shards to reduce GPU memory usage. During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer. @@ -78,6 +77,11 @@ class AirLLMLlama2(GenerationMixin): self.profiling_mode = profiling_mode + self.profiler = LayeredProfiler() + + self.total_disk_loading_time = None + self.total_gpu_loading_time = None + self.total_compression_overhead_time = None if compression is not None: if not bitsandbytes_installed: @@ -120,6 +124,13 @@ class AirLLMLlama2(GenerationMixin): self.main_input_name = "input_ids" + # model weights prefetch cuda stream + self.prefetching = prefetching + if prefetching: + self.stream = torch.cuda.Stream() + else: + self.stream = None + def init_model(self): # Load meta model (no memory used) @@ -150,26 +161,44 @@ class AirLLMLlama2(GenerationMixin): set_module_tensor_to_device(self.model, buffer_name, self.running_device, value=buffer, dtype=self.running_dtype) - def load_layer_to_cpu(self, layer_name, profiling=False): + def load_layer_to_cpu(self, layer_name): t = time.time() - load_layer_output = load_layer(self.checkpoint_path, layer_name, profiling) + + load_layer_output = load_layer(self.checkpoint_path, layer_name, self.profiling_mode) elapsed_time = time.time() - t - if profiling: + if self.profiling_mode: state_dict, compression_time = load_layer_output disk_loading_time = elapsed_time - compression_time - return state_dict, disk_loading_time, compression_time + + self.profiler.add_profiling_time('load_safe_tensor', disk_loading_time) + + self.profiler.add_profiling_time('compression_time', compression_time) else: state_dict = load_layer_output - return state_dict + # pin memory: + + t = time.time() + + for k in state_dict.keys(): + #state_dict[k] = state_dict[k].to(torch.float) + state_dict[k].pin_memory() + + + elapsed_time = time.time() - t + if self.profiling_mode: + self.profiler.add_profiling_time('pin_memory_time', elapsed_time) + + return state_dict def move_layer_to_device(self, state_dict): for param_name, param in state_dict.items(): assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)" set_module_tensor_to_device(self.model, param_name, self.running_device, value=param, - dtype=self.running_dtype) + dtype=self.running_dtype, + ) # make GenerationMixin happy def can_generate(self): @@ -217,6 +246,8 @@ class AirLLMLlama2(GenerationMixin): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) + + def forward( self, input_ids: torch.LongTensor = None, @@ -231,16 +262,13 @@ class AirLLMLlama2(GenerationMixin): return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time - if cache_utils_installed: # we don't support kv cache for new version yet use_cache = False if self.profiling_mode: - total_disk_loading_time = [] - total_gpu_loading_time = [] - total_compression_overhead_time = [] + self.profiler.clear_profiling_time() + forward_start = time.process_time() forward_start_wall = time.time() @@ -266,28 +294,61 @@ class AirLLMLlama2(GenerationMixin): all_hidden_states = [] * len(self.layers) if output_hidden_states else None all_self_attns = [] * len(self.layers) if output_attentions else None - with torch.inference_mode(): + with torch.inference_mode(), ThreadPoolExecutor() as executor: + + # Load first layer + if self.prefetching: + #with torch.cuda.stream(self.stream): + #state_dict = self.load_layer_to_cpu(self.layer_names[0]) + future = executor.submit(self.load_layer_to_cpu, self.layer_names[0]) + for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device, total=len(self.layers)): - load_layer_to_cpu_output = self.load_layer_to_cpu(layer_name, self.profiling_mode) - # profile - if self.profiling_mode: - state_dict, disk_loading_time, compression_time = load_layer_to_cpu_output - total_disk_loading_time.append(disk_loading_time) - total_compression_overhead_time.append(compression_time) - else: - state_dict = load_layer_to_cpu_output + if self.prefetching: + if self.profiling_mode: + t = time.time() + # Load current layer and prepare next layer + state_dict = future.result() + #torch.cuda.current_stream().wait_stream(self.stream) + if self.profiling_mode: + elapsed_time = time.time() - t + self.profiler.add_profiling_time('load_safe_tensor_cpu_wait', elapsed_time) - t = time.time() - self.move_layer_to_device(state_dict) - if self.profiling_mode: - torch.cuda.synchronize() - elapsed_time = time.time() - t - # profile - if self.profiling_mode: - total_gpu_loading_time.append(elapsed_time) + #for param_name, param in state_dict.items(): + # state_dict[param_name] = param.to('cuda', non_blocking=True) + + if self.profiling_mode: + t = time.time() + self.move_layer_to_device(state_dict) + if self.profiling_mode: + elapsed_time = time.time() - t + self.profiler.add_profiling_time('create_layer_from_state_dict', elapsed_time) + + # kick off next layer loading + + if (i + 1) < len(self.layer_names): + #with torch.cuda.stream(self.stream): + #state_dict = self.load_layer_to_cpu(self.layer_names[i + 1]) + if self.profiling_mode: + t = time.time() + future = executor.submit(self.load_layer_to_cpu, self.layer_names[i+1]) + #for param_name, param in state_dict.items(): + # state_dict[param_name] = param.to('cuda', non_blocking=True) + + if self.profiling_mode: + elapsed_time = time.time() - t + self.profiler.add_profiling_time('kick_off_load_cpu', elapsed_time) + + else: + state_dict = self.load_layer_to_cpu(layer_name) + if self.profiling_mode: + t = time.time() + self.move_layer_to_device(state_dict) + if self.profiling_mode: + elapsed_time = time.time() - t + self.profiler.add_profiling_time('create_layer_from_safe_tensor', elapsed_time) # Run layer @@ -390,22 +451,13 @@ class AirLLMLlama2(GenerationMixin): if self.profiling_mode: forward_elapsed_time = time.process_time() - forward_start forward_elapsed_time_wall = time.time() - forward_start_wall - if self.compression: - print(f"total disk loading time: {sum(total_disk_loading_time):.04f}") - print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}") - print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}") - else: - # loading is async/lazy, so can't really distinguish them... - print(f"total disk+gpu loading time: {sum(total_disk_loading_time) + sum(total_gpu_loading_time):.04f}") - #print(f"total disk loading time: {sum(total_disk_loading_time):.04f}") - #print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}") + self.profiler.print_profiling_time() + print(f"total infer process time(including all above plus gpu compute): {forward_elapsed_time:.04f}") print(f"total infer wall time(including all above plus gpu compute): {forward_elapsed_time_wall:.04f}") - total_disk_loading_time = [] - total_gpu_loading_time = [] - total_compression_overhead_time = [] + self.profiler.clear_profiling_time() return CausalLMOutputWithPast( diff --git a/air_llm/airllm/profiler.py b/air_llm/airllm/profiler.py new file mode 100644 index 0000000..e5a9787 --- /dev/null +++ b/air_llm/airllm/profiler.py @@ -0,0 +1,25 @@ + + + + +class LayeredProfiler: + def __init__(self): + self.profiling_time_dict = {} + + + def add_profiling_time(self, item, time): + + if not item in self.profiling_time_dict: + self.profiling_time_dict[item] = [] + + self.profiling_time_dict[item].append(time) + + + def clear_profiling_time(self): + for item in self.profiling_time_dict.keys(): + self.profiling_time_dict[item] = [] + + def print_profiling_time(self): + for item in self.profiling_time_dict.keys(): + print(f"total time for {item}: {sum(self.profiling_time_dict[item])}") + diff --git a/air_llm/airllm/utils.py b/air_llm/airllm/utils.py index bdc6d28..0440b0f 100644 --- a/air_llm/airllm/utils.py +++ b/air_llm/airllm/utils.py @@ -8,9 +8,14 @@ from pathlib import Path from glob import glob import time +from collections import OrderedDict, defaultdict +from typing import Dict, List, Optional, Tuple, Union + import torch +import torch.nn as nn from safetensors.torch import load_file, save_file + try: import bitsandbytes as bnb @@ -107,7 +112,7 @@ def load_layer(local_path, layer_name, profiling=False): to_return = uncompress_layer_state_dict(layer_state_dict) - clean_memory() + #clean_memory() if profiling: elapsed_time = time.process_time() - t @@ -338,4 +343,3 @@ def find_or_create_local_splitted_path(model_local_path_or_repo_id, layer_shards # if splitted_model subdir exists under cache use it, otherwise split and save return Path(hf_cache_path), split_and_save_layers(hf_cache_path, layer_shards_saving_path, compression=compression, layer_names=layer_names) -