fix new_seq

This commit is contained in:
Yu Li
2023-12-17 17:12:04 -06:00
parent 4c1c1cf7e8
commit 97e96d8be5

View File

@@ -29,6 +29,18 @@ try:
except ImportError:
bitsandbytes_installed = False
try:
from transformers.cache_utils import Cache, DynamicCache
cache_utils_installed = True
print('>>>> cache_utils installed')
except ImportError:
cache_utils_installed = False
total_disk_loading_time = None
total_gpu_loading_time = None
total_compression_overhead_time = None
@@ -111,14 +123,24 @@ class AirLLMLlama2(GenerationMixin):
def init_model(self):
# Load meta model (no memory used)
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(self.config)
self.model.eval()
try:
try:
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(self.config)
self.model.eval()
self.model = BetterTransformer.transform(self.model) # enable flash attention
except ValueError as ve:
print(f"new version of transfomer, no need to use BetterTransformer...")
self.model.tie_weights()
self.model.tie_weights()
except ValueError as ve:
del self.model
clean_memory()
print(f"new version of transfomer, no need to use BetterTransformer, setting attn impl to sdpa...")
self.config.attn_implementation = "sdpa"
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa")
self.model.eval()
self.model.tie_weights()
print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
self.model.lm_head]
@@ -211,6 +233,10 @@ class AirLLMLlama2(GenerationMixin):
global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
if cache_utils_installed:
# we don't support kv cache for new version yet
use_cache = False
if self.profiling_mode:
total_disk_loading_time = []
total_gpu_loading_time = []
@@ -318,17 +344,13 @@ class AirLLMLlama2(GenerationMixin):
new_seq = layer(seq,
attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])[0]
else:
layer_out = layer(seq,
use_cache=True,
attention_mask=attention_mask[:, :, -len_seq:,
-len_seq:])
# TODO: adopt Cache mechanism in 4.36
if layer_out[1] is not None:
layer_out = layer(seq, use_cache=True,
attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
#print(f"layer:{type(layer)} layer_out:{layer_out}")
new_seq, (k_cache, v_cache) = layer_out
kv_cache_list[i][0].append(k_cache)
kv_cache_list[i][1].append(v_cache)
# TODO: adopt Cache mechanism in 4.36
new_seq, (k_cache, v_cache) = layer_out
kv_cache_list[i][0].append(k_cache)
kv_cache_list[i][1].append(v_cache)
# print(f"k_cache size: {k_cache.shape}")
# print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")