support Llama3.1 405B

Yu Li
2024-07-30 22:30:05 -05:00
parent 65900351e1
commit 9ace0fc57b
2 changed files with 51 additions and 15 deletions

airllm/airllm_base.py

@@ -11,6 +11,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from accelerate import init_empty_weights
 from accelerate.utils.modeling import set_module_tensor_to_device
+from transformers.quantizers import AutoHfQuantizer, HfQuantizer
 from .profiler import LayeredProfiler
 from optimum.bettertransformer import BetterTransformer
@@ -87,6 +89,8 @@ class AirLLMBaseModel(GenerationMixin):
         self.total_disk_loading_time = None
         self.total_gpu_loading_time = None
         self.total_compression_overhead_time = None
+        self._supports_cache_class = False
+        self.hf_quantizer = None

         if compression is not None:
             if not bitsandbytes_installed:
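
Both new attributes are used later in the diff: `_supports_cache_class = False` opts the model out of transformers' newer `Cache` objects, so `GenerationMixin` keeps the legacy tuple KV cache that the manual layer-by-layer forward pass expects, and `hf_quantizer` will hold the checkpoint's `HfQuantizer` when a pre-quantized model is detected, which is presumably the practical way to fit Llama 3.1 405B.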
@@ -183,9 +187,7 @@ class AirLLMBaseModel(GenerationMixin):
         try:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
-                self.model.eval()
                 self.model = BetterTransformer.transform(self.model) # enable flash attention
-                self.model.tie_weights()
         except ValueError as ve:
             del self.model
             clean_memory()
@@ -200,8 +202,6 @@ class AirLLMBaseModel(GenerationMixin):
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa", trust_remote_code=True)
-                self.model.eval()
-                self.model.tie_weights()
                 print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
         except TypeError as ve:
@@ -214,9 +214,16 @@ class AirLLMBaseModel(GenerationMixin):
             print(f"neither BetterTransformer nor attn_implementation='sdpa' is available, creating model directly")
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
-                self.model.eval()
-                self.model.tie_weights()
+
+        quantization_config = getattr(self.config, "quantization_config", None)
+        if quantization_config is not None:
+            self.hf_quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=True)
+            device_map = self.hf_quantizer.update_device_map(None)
+            self.hf_quantizer.preprocess_model(model=self.model, device_map=device_map)
+
+        self.model.eval()
+        self.model.tie_weights()

         self.set_layers_from_layer_names()
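
For orientation, a minimal sketch of the quantizer path this hunk adds, assuming a checkpoint whose config carries a quantization_config (e.g. a pre-quantized 4-bit Llama 3.1 405B); the model id below is a placeholder, not from the commit:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer

config = AutoConfig.from_pretrained("some-org/llama-3.1-405b-4bit")  # hypothetical id

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)  # weights on "meta", no storage

quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
device_map = quantizer.update_device_map(None)  # let the quantizer pick placement
quantizer.preprocess_model(model=model, device_map=device_map)
# preprocess_model swaps nn.Linear modules for their quantized counterparts
# (e.g. bitsandbytes Linear4bit) while weights still live on the meta device,
# so create_quantized_param() can materialize them layer by layer later.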
@@ -288,11 +295,27 @@ class AirLLMBaseModel(GenerationMixin):
         return state_dict

     def move_layer_to_device(self, state_dict):
+        layers = []
         for param_name, param in state_dict.items():
-            #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
-            set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
-                                        dtype=self.running_dtype,
-                                        )
+            if self.hf_quantizer is None:
+                layers.append(param_name)
+            else:
+                if '.weight' in param_name:
+                    layer_name = param_name[:param_name.index(".weight") + len(".weight")]
+                    if layer_name not in layers:
+                        layers.append(layer_name)
+
+        for param_name in layers:
+            if (self.hf_quantizer is None or
+                not self.hf_quantizer.check_quantized_param(self.model, param_value=None, param_name=param_name, state_dict={})
+            ):
+                set_module_tensor_to_device(self.model, param_name, self.running_device, value=state_dict[param_name],
+                                            dtype=self.running_dtype,
+                                            )
+            else:
+                torch_dtype = self.hf_quantizer.update_torch_dtype(None)
+                self.hf_quantizer.create_quantized_param(self.model, state_dict[param_name], param_name, self.running_device, state_dict)
+        return layers

     # make GenerationMixin happy
     def can_generate(self):
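
Why the quantized branch collapses names to their ".weight" prefix: a pre-quantized bitsandbytes checkpoint stores several auxiliary tensors per weight (scales, quant maps) that create_quantized_param() consumes together, so only one entry per weight should be tracked. A standalone sketch with illustrative key names:

state_dict_keys = [
    "model.layers.0.self_attn.q_proj.weight",         # packed 4-bit data
    "model.layers.0.self_attn.q_proj.weight.absmax",  # quantization statistics
    "model.layers.0.self_attn.q_proj.weight.quant_map",
    "model.layers.0.input_layernorm.weight",          # not quantized
]

layers = []
for param_name in state_dict_keys:
    if '.weight' in param_name:
        layer_name = param_name[:param_name.index(".weight") + len(".weight")]
        if layer_name not in layers:
            layers.append(layer_name)

print(layers)
# ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.input_layernorm.weight']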
@@ -420,7 +443,7 @@ class AirLLMBaseModel(GenerationMixin):
         for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)),
-                                           desc=f'running layers(self.running_device)',
+                                           desc=f'running layers({self.running_device})',
                                            total=len(self.layers)):

             if self.prefetching:
@@ -438,7 +461,7 @@ class AirLLMBaseModel(GenerationMixin):
                 if self.profiling_mode:
                     t = time.time()
-                self.move_layer_to_device(state_dict)
+                moved_layers = self.move_layer_to_device(state_dict)
                 if self.profiling_mode:
                     elapsed_time = time.time() - t
                     self.profiler.add_profiling_time('create_layer_from_state_dict', elapsed_time)
@@ -462,7 +485,7 @@ class AirLLMBaseModel(GenerationMixin):
state_dict = self.load_layer_to_cpu(layer_name)
if self.profiling_mode:
t = time.time()
self.move_layer_to_device(state_dict)
moved_layers = self.move_layer_to_device(state_dict)
if self.profiling_mode:
elapsed_time = time.time() - t
self.profiler.add_profiling_time('create_layer_from_safe_tensor', elapsed_time)
@@ -561,6 +584,13 @@ class AirLLMBaseModel(GenerationMixin):
                 all_hidden_states += (torch.cat(batch, 0),)

             # Remove previous layer from memory (including buffers)
-            layer.to("meta")
+            if self.hf_quantizer is not None:
+                for param_name in moved_layers:
+                    set_module_tensor_to_device(self.model, param_name, 'meta')
+            else:
+                layer.to("meta")
             clean_memory() # proposed by CPMP
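
For context, a self-contained sketch of the offload trick used above (assumes a CUDA device; the Linear layer stands in for a transformer block):

import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)  # parameters on "meta": shapes only, no storage

# materialize the parameters on the GPU for one forward pass
set_module_tensor_to_device(layer, "weight", "cuda",
                            value=torch.randn(4096, 4096, dtype=torch.float16))
set_module_tensor_to_device(layer, "bias", "cuda",
                            value=torch.zeros(4096, dtype=torch.float16))

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = layer(x)

# release the weights: Module.to("meta") drops plain tensors, but quantized
# parameters must be detached one by one, hence the moved_layers bookkeeping
layer.to("meta")
torch.cuda.empty_cache()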

airllm/profiler.py

@@ -1,10 +1,12 @@
 import torch

 class LayeredProfiler:
-    def __init__(self):
+    def __init__(self, print_memory=False):
         self.profiling_time_dict = {}
+        self.print_memory = print_memory
+        self.min_free_mem = 1024*1024*1024*1024

     def add_profiling_time(self, item, time):
def add_profiling_time(self, item, time):
@@ -14,6 +16,10 @@ class LayeredProfiler:
         self.profiling_time_dict[item].append(time)

+        if self.print_memory:
+            free_mem = torch.cuda.mem_get_info()[0]
+            self.min_free_mem = min(self.min_free_mem, free_mem)
+            print(f"free vmem @{item}: {free_mem/1024/1024/1024:.02f}GB, min free: {self.min_free_mem/1024/1024/1024:.02f}GB")

     def clear_profiling_time(self):
         for item in self.profiling_time_dict.keys():
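
A hypothetical usage of the extended profiler (assumes a CUDA device; the import path is assumed). With print_memory=True, every add_profiling_time() call also reports the current and minimum observed free VRAM via torch.cuda.mem_get_info(), which is handy for sizing the per-layer working set of a 405B model:

import time
from airllm.profiler import LayeredProfiler  # import path assumed

profiler = LayeredProfiler(print_memory=True)  # requires a CUDA device

t = time.time()
# ... load one layer from disk and run it ...
profiler.add_profiling_time('create_layer_from_safe_tensor', time.time() - t)
# prints e.g.: free vmem @create_layer_from_safe_tensor: 18.42GB, min free: 18.42GB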