Mirror of https://github.com/0xSojalSec/airllm.git
support Llama3.1 405B
@@ -11,6 +11,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from accelerate import init_empty_weights
 from accelerate.utils.modeling import set_module_tensor_to_device
+from transformers.quantizers import AutoHfQuantizer, HfQuantizer
 from .profiler import LayeredProfiler
 from optimum.bettertransformer import BetterTransformer
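These imports back the core AirLLM loading pattern: the model skeleton is created on the meta device with `init_empty_weights()` and individual tensors are later materialized with `set_module_tensor_to_device`; the new `AutoHfQuantizer` import is what lets the same flow handle pre-quantized Llama 3.1 405B checkpoints. A minimal sketch of that pattern (not part of the commit; the model id and parameter name are only examples):

```python
# Minimal sketch (not from the commit) of the pattern these imports enable:
# build the model skeleton on the meta device, then materialize tensors one at a time.
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from transformers import AutoConfig, AutoModelForCausalLM

# model id is only an example (gated on the Hub); any causal-LM config works the same way
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3.1-405B")

with init_empty_weights():  # parameters live on the meta device, no RAM is allocated
    model = AutoModelForCausalLM.from_config(config)

# later, tensors read from disk are placed onto the real device one by one
dummy = torch.zeros(config.hidden_size)  # stand-in for a tensor loaded from a safetensors shard
set_module_tensor_to_device(model, "model.norm.weight", "cpu", value=dummy, dtype=torch.float16)
```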
@@ -87,6 +89,8 @@ class AirLLMBaseModel(GenerationMixin):
         self.total_disk_loading_time = None
         self.total_gpu_loading_time = None
         self.total_compression_overhead_time = None
+        self._supports_cache_class = False
+        self.hf_quantizer = None

         if compression is not None:
             if not bitsandbytes_installed:
@@ -183,9 +187,7 @@ class AirLLMBaseModel(GenerationMixin):
         try:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
-                self.model.eval()
                 self.model = BetterTransformer.transform(self.model) # enable flash attention
-                self.model.tie_weights()
         except ValueError as ve:
             del self.model
             clean_memory()
@@ -200,8 +202,6 @@ class AirLLMBaseModel(GenerationMixin):

                 with init_empty_weights():
                     self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa", trust_remote_code=True)
-                    self.model.eval()
-                    self.model.tie_weights()
                 print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")

             except TypeError as ve:
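Together with the previous hunk, this keeps the existing attention fallback chain intact: try optimum's BetterTransformer first, fall back to `attn_implementation="sdpa"` if that raises, and finally build the model with the default attention implementation. A condensed, hedged sketch of the chain (illustrative; the real method lives in `AirLLMBaseModel` and carries extra logging):

```python
# Hedged sketch of the fallback chain used when building the empty-weight model.
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

def build_empty_model(config):
    try:
        from optimum.bettertransformer import BetterTransformer
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model = BetterTransformer.transform(model)  # flash attention, if supported
    except ValueError:
        try:
            with init_empty_weights():
                # PyTorch scaled-dot-product attention as the next best option
                model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa", trust_remote_code=True)
        except TypeError:
            with init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)  # default attention
    return model
```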
@@ -214,9 +214,16 @@ class AirLLMBaseModel(GenerationMixin):
                 print(f"either BetterTransformer or attn_implementation='sdpa' is available, creating model directly")
                 with init_empty_weights():
                     self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
                     self.model.eval()
                     self.model.tie_weights()

+        quantization_config = getattr(self.config, "quantization_config", None)
+        if quantization_config is not None:
+            self.hf_quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=True)
+            device_map = self.hf_quantizer.update_device_map(None)
+            self.hf_quantizer.preprocess_model(model = self.model, device_map = device_map)
+
+        self.model.eval()
+        self.model.tie_weights()

         self.set_layers_from_layer_names()
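This is the heart of the new quantization support: if the loaded config carries a `quantization_config`, an `HfQuantizer` is built with `pre_quantized=True` and allowed to rewrite the empty-weight model's modules before any tensors are loaded; `eval()` and `tie_weights()` now run once, after that preprocessing. A hedged standalone sketch of the same calls (the checkpoint id is only an example):

```python
# Hedged sketch of driving transformers' quantizer machinery by hand.
# The checkpoint id is illustrative; any config carrying a `quantization_config` works the same way.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer

config = AutoConfig.from_pretrained("hugging-quants/Meta-Llama-3.1-405B-Instruct-FP8")  # example id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

quantization_config = getattr(config, "quantization_config", None)
if quantization_config is not None:
    # pre_quantized=True: the checkpoint already stores quantized weights,
    # so the quantizer only swaps in the matching quantized modules.
    hf_quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=True)
    device_map = hf_quantizer.update_device_map(None)
    hf_quantizer.preprocess_model(model=model, device_map=device_map)
```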
@@ -288,11 +295,27 @@ class AirLLMBaseModel(GenerationMixin):
         return state_dict

     def move_layer_to_device(self, state_dict):
+        layers = []
         for param_name, param in state_dict.items():
-            #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
-            set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
-                                        dtype=self.running_dtype,
-                                        )
+            if self.hf_quantizer is None:
+                layers.append(param_name)
+            else:
+                if '.weight' in param_name:
+                    layer_name = param_name[:param_name.index(".weight") + len(".weight")]
+                    if layer_name not in layers:
+                        layers.append(layer_name)
+
+        for param_name in layers:
+            if (self.hf_quantizer is None or
+                not self.hf_quantizer.check_quantized_param(self.model, param_value=None, param_name=param_name, state_dict={})
+               ):
+                set_module_tensor_to_device(self.model, param_name, self.running_device, value=state_dict[param_name],
+                                            dtype=self.running_dtype,
+                                            )
+            else:
+                torch_dtype = self.hf_quantizer.update_torch_dtype(None)
+                self.hf_quantizer.create_quantized_param(self.model, state_dict[param_name], param_name, self.running_device, state_dict)
+        return layers

     # make GenerationMixin happy
     def can_generate(self):
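`move_layer_to_device` now returns the names it materialized: full parameter names for plain models, or module-level `.weight` entries for quantized ones, so the caller can later push exactly those tensors back to the meta device. A tiny illustration of the name handling (the parameter name below is hypothetical):

```python
# Tiny illustration (not from the commit) of the ".weight" truncation above.
param_name = "model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4"  # hypothetical quantized entry
layer_name = param_name[:param_name.index(".weight") + len(".weight")]
assert layer_name == "model.layers.0.self_attn.q_proj.weight"
```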
@@ -420,7 +443,7 @@ class AirLLMBaseModel(GenerationMixin):

         for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)),
-                                           desc=f'running layers(self.running_device)',
+                                           desc=f'running layers({self.running_device})',
                                            total=len(self.layers)):

             if self.prefetching:
@@ -438,7 +461,7 @@ class AirLLMBaseModel(GenerationMixin):

             if self.profiling_mode:
                 t = time.time()
-            self.move_layer_to_device(state_dict)
+            moved_layers = self.move_layer_to_device(state_dict)
             if self.profiling_mode:
                 elapsed_time = time.time() - t
                 self.profiler.add_profiling_time('create_layer_from_state_dict', elapsed_time)
@@ -462,7 +485,7 @@ class AirLLMBaseModel(GenerationMixin):
             state_dict = self.load_layer_to_cpu(layer_name)
             if self.profiling_mode:
                 t = time.time()
-            self.move_layer_to_device(state_dict)
+            moved_layers = self.move_layer_to_device(state_dict)
             if self.profiling_mode:
                 elapsed_time = time.time() - t
                 self.profiler.add_profiling_time('create_layer_from_safe_tensor', elapsed_time)
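Both call sites wrap `move_layer_to_device` in the same "check `profiling_mode`, stamp the time, record the elapsed time" bookkeeping. A hedged sketch of how that pattern could be expressed once with a context manager (illustrative only, not part of the commit):

```python
import time
from contextlib import contextmanager

@contextmanager
def profile_step(profiler, name, enabled):
    """Record wall-clock time for a block into `profiler` when profiling is enabled."""
    start = time.time() if enabled else None
    try:
        yield
    finally:
        if enabled:
            profiler.add_profiling_time(name, time.time() - start)

# usage, mirroring the call sites above:
# with profile_step(self.profiler, 'create_layer_from_state_dict', self.profiling_mode):
#     moved_layers = self.move_layer_to_device(state_dict)
```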
@@ -561,6 +584,13 @@ class AirLLMBaseModel(GenerationMixin):
                 all_hidden_states += (torch.cat(batch, 0),)

             # Remove previous layer from memory (including buffers)
+            if self.hf_quantizer is not None:
+                for param_name in moved_layers: #param_name, param in state_dict.items():
+                    set_module_tensor_to_device(self.model, param_name, 'meta')
+            else:
+                layer.to("meta")
-            layer.to("meta")
             clean_memory() # proposed by CPMP
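Quantized parameters are created tensor-by-tensor by the quantizer, so they are also released tensor-by-tensor using the names returned by `move_layer_to_device`; plain layers can still be offloaded wholesale with `layer.to("meta")`. A hedged illustration of the two offload paths (`offload_layer` is a made-up helper, not the commit's code):

```python
# Hedged illustration (not from the commit) of the two offload paths above.
from accelerate.utils.modeling import set_module_tensor_to_device

def offload_layer(model, layer, moved_layers, hf_quantizer):
    if hf_quantizer is not None:
        # quantized params were materialized one by one, so release them one by one
        for param_name in moved_layers:
            set_module_tensor_to_device(model, param_name, 'meta')
    else:
        # plain (non-quantized) layers can be pushed back to the meta device wholesale
        layer.to("meta")
```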
profiler.py:

@@ -1,10 +1,12 @@

 import torch


 class LayeredProfiler:
-    def __init__(self):
+    def __init__(self, print_memory=False):
         self.profiling_time_dict = {}
+        self.print_memory = print_memory
+        self.min_free_mem = 1024*1024*1024*1024

     def add_profiling_time(self, item, time):
@@ -14,6 +16,10 @@ class LayeredProfiler:
         self.profiling_time_dict[item].append(time)

+        if self.print_memory:
+            free_mem = torch.cuda.mem_get_info()[0]
+            self.min_free_mem = min(self.min_free_mem, free_mem)
+            print(f"free vmem @{item}: {free_mem/1024/1024/1024:.02f}GB, min free: {self.min_free_mem/1024/1024/1024:.02f}GB")

     def clear_profiling_time(self):
         for item in self.profiling_time_dict.keys():
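With `print_memory=True` the profiler now samples free GPU memory via `torch.cuda.mem_get_info()` after every timed step and tracks the minimum it has seen, which makes it easy to spot the VRAM high-water mark while layers stream through. A hedged usage sketch (the import path is assumed; inside AirLLM the profiler is constructed internally):

```python
# Hedged usage sketch for the new print_memory flag (illustrative, not from the commit).
import time
from airllm.profiler import LayeredProfiler  # import path assumed; adjust to your install

profiler = LayeredProfiler(print_memory=True)

t = time.time()
# ... run one layer ...
profiler.add_profiling_time('create_layer_from_state_dict', time.time() - t)
# prints e.g. "free vmem @create_layer_from_state_dict: 12.34GB, min free: 12.34GB"
```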