diff --git a/air_llm/airllm/airllm_base.py b/air_llm/airllm/airllm_base.py
index 89d02f7..6dcbbb0 100644
--- a/air_llm/airllm/airllm_base.py
+++ b/air_llm/airllm/airllm_base.py
@@ -11,6 +11,8 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from accelerate import init_empty_weights
 from accelerate.utils.modeling import set_module_tensor_to_device
+from transformers.quantizers import AutoHfQuantizer, HfQuantizer
+
 from .profiler import LayeredProfiler
 
 from optimum.bettertransformer import BetterTransformer
@@ -87,6 +89,8 @@ class AirLLMBaseModel(GenerationMixin):
         self.total_disk_loading_time = None
         self.total_gpu_loading_time = None
         self.total_compression_overhead_time = None
+        self._supports_cache_class = False
+        self.hf_quantizer = None  # set below when the checkpoint carries a quantization_config
 
         if compression is not None:
             if not bitsandbytes_installed:
@@ -183,9 +187,7 @@
         try:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
-                self.model.eval()
                 self.model = BetterTransformer.transform(self.model)  # enable flash attention
-                self.model.tie_weights()
         except ValueError as ve:
             del self.model
             clean_memory()
@@ -200,8 +202,6 @@
                 with init_empty_weights():
                     self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa",
                                                                   trust_remote_code=True)
-                    self.model.eval()
-                    self.model.tie_weights()
 
                     print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
             except TypeError as ve:
@@ -214,9 +214,16 @@
 
                 print(f"either BetterTransformer or attn_implementation='sdpa' is available, creating model directly")
                 with init_empty_weights():
                     self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
-                    self.model.eval()
-                    self.model.tie_weights()
+        quantization_config = getattr(self.config, "quantization_config", None)
+
+        if quantization_config is not None:
+            self.hf_quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=True)  # checkpoint tensors are already quantized
+            device_map = self.hf_quantizer.update_device_map(None)
+            self.hf_quantizer.preprocess_model(model=self.model, device_map=device_map)
+
+        self.model.eval()
+        self.model.tie_weights()
 
         self.set_layers_from_layer_names()
 
@@ -288,11 +295,27 @@
         return state_dict
 
     def move_layer_to_device(self, state_dict):
+        layers = []
         for param_name, param in state_dict.items():
-            #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
-            set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
-                                        dtype=self.running_dtype,
-                                        )
+            if self.hf_quantizer is None:
+                layers.append(param_name)
+            else:
+                # group quantized tensors under their owning ".weight" entry; the
+                # side-car tensors (scales, zero points) are loaded along with it
+                if '.weight' in param_name:
+                    layer_name = param_name[:param_name.index(".weight") + len(".weight")]
+                    if layer_name not in layers:
+                        layers.append(layer_name)
+
+        for param_name in layers:
+            if (self.hf_quantizer is None or
+                    not self.hf_quantizer.check_quantized_param(self.model, param_value=None, param_name=param_name, state_dict={})):
+                set_module_tensor_to_device(self.model, param_name, self.running_device, value=state_dict[param_name],
+                                            dtype=self.running_dtype)
+            else:
+                # let the quantizer materialize the quantized parameter on device
+                self.hf_quantizer.create_quantized_param(self.model, state_dict[param_name], param_name, self.running_device, state_dict)
+        return layers
 
     # make GenerationMixin happy
     def can_generate(self):
@@ -420,7 +443,7 @@
 
 
         for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)),
-                                           desc=f'running layers(self.running_device)',
+                                           desc=f'running layers({self.running_device})',
                                            total=len(self.layers)):
 
             if self.prefetching:
@@ -438,7 +461,7 @@
                 if self.profiling_mode:
                     t = time.time()
 
-                self.move_layer_to_device(state_dict)
+                moved_layers = self.move_layer_to_device(state_dict)
                 if self.profiling_mode:
                     elapsed_time = time.time() - t
                     self.profiler.add_profiling_time('create_layer_from_state_dict', elapsed_time)
@@ -462,7 +485,7 @@
                 state_dict = self.load_layer_to_cpu(layer_name)
                 if self.profiling_mode:
                     t = time.time()
-                self.move_layer_to_device(state_dict)
+                moved_layers = self.move_layer_to_device(state_dict)
                 if self.profiling_mode:
                     elapsed_time = time.time() - t
                     self.profiler.add_profiling_time('create_layer_from_safe_tensor', elapsed_time)
@@ -561,6 +584,13 @@
                 all_hidden_states += (torch.cat(batch, 0),)
 
             # Remove previous layer from memory (including buffers)
-            layer.to("meta")
+
+            if self.hf_quantizer is not None:
+                # quantized params were materialized one by one; off-load them the same way
+                for param_name in moved_layers:
+                    set_module_tensor_to_device(self.model, param_name, 'meta')
+            else:
+                layer.to("meta")
+
             clean_memory()  # proposed by CPMP
 
diff --git a/air_llm/airllm/profiler.py b/air_llm/airllm/profiler.py
index e5a9787..d457605 100644
--- a/air_llm/airllm/profiler.py
+++ b/air_llm/airllm/profiler.py
@@ -1,10 +1,12 @@
-
+import torch
 
 
 class LayeredProfiler:
 
-    def __init__(self):
+    def __init__(self, print_memory=False):
         self.profiling_time_dict = {}
+        self.print_memory = print_memory
+        self.min_free_mem = 1024*1024*1024*1024  # 1 TiB sentinel, shrinks to the observed minimum
 
     def add_profiling_time(self, item, time):
@@ -14,6 +16,10 @@
 
         self.profiling_time_dict[item].append(time)
 
+        if self.print_memory:
+            free_mem = torch.cuda.mem_get_info()[0]  # mem_get_info() returns (free, total) in bytes
+            self.min_free_mem = min(self.min_free_mem, free_mem)
+            print(f"free vmem @{item}: {free_mem/1024/1024/1024:.02f}GB, min free: {self.min_free_mem/1024/1024/1024:.02f}GB")
 
     def clear_profiling_time(self):
         for item in self.profiling_time_dict.keys():
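For reference, here is the new loading path in condensed, standalone form. This is a minimal sketch, not the AirLLM code itself: the checkpoint id and the target device are placeholder assumptions, while the quantizer calls follow the transformers >= 4.36 AutoHfQuantizer API used in the diff above.

    from accelerate import init_empty_weights
    from accelerate.utils.modeling import set_module_tensor_to_device
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.quantizers import AutoHfQuantizer

    config = AutoConfig.from_pretrained("some/pre-quantized-checkpoint")  # placeholder id

    # build the model shell on the meta device, exactly as AirLLM does
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    hf_quantizer = None
    quantization_config = getattr(config, "quantization_config", None)
    if quantization_config is not None:
        # pre_quantized=True: the checkpoint already stores quantized tensors
        hf_quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=True)
        hf_quantizer.preprocess_model(model=model, device_map=hf_quantizer.update_device_map(None))
    model.eval()
    model.tie_weights()

    def move_layer_to_device(state_dict, device="cuda:0"):  # device is an assumption
        # Plain checkpoints: materialize every tensor individually. Quantized
        # checkpoints: collapse names to the owning ".weight" entry, because
        # create_quantized_param() pulls the side-car tensors (scales, zero
        # points) out of the full state_dict itself.
        if hf_quantizer is None:
            names = list(state_dict)
        else:
            names = []
            for param_name in state_dict:
                if ".weight" in param_name:
                    layer_name = param_name[:param_name.index(".weight") + len(".weight")]
                    if layer_name not in names:
                        names.append(layer_name)

        for name in names:
            if hf_quantizer is not None and hf_quantizer.check_quantized_param(
                    model, param_value=None, param_name=name, state_dict={}):
                hf_quantizer.create_quantized_param(model, state_dict[name], name, device, state_dict)
            else:
                set_module_tensor_to_device(model, name, device, value=state_dict[name])
        return names

The returned names are why the forward loop now keeps `moved_layers`: quantized parameter classes handle `.to()` specially, so after a layer has run, each moved name is sent back to the meta device explicitly via `set_module_tensor_to_device(self.model, param_name, 'meta')` rather than relying on a blanket `layer.to("meta")`. The profiler change is independent of all this: constructing `LayeredProfiler(print_memory=True)` makes every profiling event also print the current and minimum observed free VRAM from `torch.cuda.mem_get_info()`.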