fix new_seq

commit 97e96d8be5
parent 4c1c1cf7e8
Author: Yu Li
Date:   2023-12-17 17:12:04 -06:00

@@ -29,6 +29,18 @@ try:
 except ImportError:
     bitsandbytes_installed = False
 
+try:
+    from transformers.cache_utils import Cache, DynamicCache
+    cache_utils_installed = True
+    print('>>>> cache_utils installed')
+except ImportError:
+    cache_utils_installed = False
+
 total_disk_loading_time = None
 total_gpu_loading_time = None
 total_compression_overhead_time = None
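
Note: the hunk above feature-detects the Cache API that shipped in transformers 4.36. A minimal sketch of the same detect-and-branch pattern, assuming only that transformers.cache_utils exists from 4.36 onward (make_empty_cache is a hypothetical helper, not part of this commit):

    try:
        from transformers.cache_utils import DynamicCache  # transformers >= 4.36
        cache_utils_installed = True
    except ImportError:
        cache_utils_installed = False

    def make_empty_cache():
        # newer releases pass a Cache object as past_key_values;
        # older ones use a tuple-of-tuples (or None)
        return DynamicCache() if cache_utils_installed else None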
@@ -111,14 +123,24 @@ class AirLLMLlama2(GenerationMixin):
     def init_model(self):
         # Load meta model (no memory used)
-        with init_empty_weights():
-            self.model = AutoModelForCausalLM.from_config(self.config)
-            self.model.eval()
-            try:
-                self.model = BetterTransformer.transform(self.model) # enable flash attention
-            except ValueError as ve:
-                print(f"new version of transformers, no need to use BetterTransformer...")
-            self.model.tie_weights()
+        try:
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config)
+                self.model.eval()
+                self.model = BetterTransformer.transform(self.model) # enable flash attention
+                self.model.tie_weights()
+        except ValueError as ve:
+            del self.model
+            clean_memory()
+            print(f"new version of transformers, no need to use BetterTransformer, setting attn impl to sdpa...")
+            self.config.attn_implementation = "sdpa"
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa")
+                self.model.eval()
+                self.model.tie_weights()
+        print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
 
         self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
                                                                                          self.model.lm_head]
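
Note: on transformers >= 4.36, BetterTransformer.transform() raises ValueError because SDPA attention is built in, so init_model() now discards the meta model and rebuilds it with attn_implementation="sdpa". A standalone sketch of that fallback, assuming transformers >= 4.36 plus accelerate, and using a toy LlamaConfig (the sizes are made up):

    from accelerate import init_empty_weights
    from transformers import AutoModelForCausalLM, LlamaConfig

    config = LlamaConfig(hidden_size=64, intermediate_size=128,
                         num_hidden_layers=2, num_attention_heads=4)  # toy sizes
    try:
        from optimum.bettertransformer import BetterTransformer
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)
            model = BetterTransformer.transform(model)  # ValueError on >= 4.36
    except (ImportError, ValueError):
        # transformers ships its own SDPA path; rebuild the empty model with it
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")
    model.eval()
    model.tie_weights()
    print(type(model.model.layers[0].self_attn))  # e.g. LlamaSdpaAttention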
@@ -211,6 +233,10 @@ class AirLLMLlama2(GenerationMixin):
         global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
 
+        if cache_utils_installed:
+            # we don't support kv cache for new version yet
+            use_cache = False
+
         if self.profiling_mode:
             total_disk_loading_time = []
             total_gpu_loading_time = []
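
Note: use_cache is forced off here because the tuple-based kv_cache_list bookkeeping further down predates the 4.36 Cache object (hence the commit's "we don't support kv cache for new version yet"). A short sketch of the API difference that motivates this, assuming transformers >= 4.36; the tensor sizes are made up:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()                  # one object shared by all layers
    k = torch.zeros(1, 4, 3, 16)            # (batch, heads, seq_len, head_dim)
    v = torch.zeros(1, 4, 3, 16)
    cache.update(k, v, layer_idx=0)         # new-style update in 4.36
    legacy = cache.to_legacy_cache()        # tuple-of-tuples the old path expects
    print(len(legacy), legacy[0][0].shape)  # 1 torch.Size([1, 4, 3, 16])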
@@ -318,17 +344,13 @@ class AirLLMLlama2(GenerationMixin):
                         new_seq = layer(seq,
                                         attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])[0]
                     else:
-                        layer_out = layer(seq,
-                                          use_cache=True,
-                                          attention_mask=attention_mask[:, :, -len_seq:,
-                                                         -len_seq:])
+                        layer_out = layer(seq, use_cache=True,
+                                          attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
 
                         # TODO: adopt Cache mechanism in 4.36
-                        if layer_out[1] is not None:
-                            #print(f"layer:{type(layer)} layer_out:{layer_out}")
-                            new_seq, (k_cache, v_cache) = layer_out
-                            kv_cache_list[i][0].append(k_cache)
-                            kv_cache_list[i][1].append(v_cache)
+                        new_seq, (k_cache, v_cache) = layer_out
+                        kv_cache_list[i][0].append(k_cache)
+                        kv_cache_list[i][1].append(v_cache)
                         # print(f"k_cache size: {k_cache.shape}")
                         # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")