From 97e96d8be56cd9f138dfec5147ea8a5aa4e49b15 Mon Sep 17 00:00:00 2001
From: Yu Li
Date: Sun, 17 Dec 2023 17:12:04 -0600
Subject: [PATCH] fix new_seq

---
 air_llm/airllm/airllm.py | 56 ++++++++++++++++++++++++++++------------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py
index de39943..2205042 100644
--- a/air_llm/airllm/airllm.py
+++ b/air_llm/airllm/airllm.py
@@ -29,6 +29,18 @@ try:
 except ImportError:
     bitsandbytes_installed = False
 
+
+
+try:
+    from transformers.cache_utils import Cache, DynamicCache
+
+    cache_utils_installed = True
+    print('>>>> cache_utils installed')
+except ImportError:
+    cache_utils_installed = False
+
+
+
 total_disk_loading_time = None
 total_gpu_loading_time = None
 total_compression_overhead_time = None
@@ -111,14 +123,24 @@ class AirLLMLlama2(GenerationMixin):
 
     def init_model(self):
         # Load meta model (no memory used)
-        with init_empty_weights():
-            self.model = AutoModelForCausalLM.from_config(self.config)
-            self.model.eval()
-        try:
+        try:
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config)
+                self.model.eval()
             self.model = BetterTransformer.transform(self.model) # enable flash attention
-        except ValueError as ve:
-            print(f"new version of transfomer, no need to use BetterTransformer...")
-        self.model.tie_weights()
+            self.model.tie_weights()
+        except ValueError as ve:
+            del self.model
+            clean_memory()
+
+            print(f"new version of transformers, no need to use BetterTransformer, setting attn impl to sdpa...")
+            self.config.attn_implementation = "sdpa"
+
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa")
+                self.model.eval()
+                self.model.tie_weights()
+            print(f"attn impl: {type(self.model.model.layers[3].self_attn)}")
 
         self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head]
 
@@ -211,6 +233,10 @@ class AirLLMLlama2(GenerationMixin):
 
         global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
 
+        if cache_utils_installed:
+            # we don't support kv cache for the new version yet
+            use_cache = False
+
         if self.profiling_mode:
             total_disk_loading_time = []
             total_gpu_loading_time = []
@@ -318,17 +344,13 @@ class AirLLMLlama2(GenerationMixin):
                     new_seq = layer(seq, attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])[0]
                 else:
-                    layer_out = layer(seq,
-                                      use_cache=True,
-                                      attention_mask=attention_mask[:, :, -len_seq:,
-                                                     -len_seq:])
-                    # TODO: adopt Cache mechanism in 4.36
-                    if layer_out[1] is not None:
+                    layer_out = layer(seq, use_cache=True,
+                                      attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
 
-                        #print(f"layer:{type(layer)} layer_out:{layer_out}")
-                        new_seq, (k_cache, v_cache) = layer_out
-                        kv_cache_list[i][0].append(k_cache)
-                        kv_cache_list[i][1].append(v_cache)
+                    # TODO: adopt Cache mechanism in 4.36
+                    new_seq, (k_cache, v_cache) = layer_out
+                    kv_cache_list[i][0].append(k_cache)
+                    kv_cache_list[i][1].append(v_cache)
 
                     # print(f"k_cache size: {k_cache.shape}")
                     # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
 
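--
Note, not part of the patch: the "TODO: adopt Cache mechanism in 4.36"
comment above refers to the transformers.cache_utils API that the new
import block probes for. Below is a minimal sketch, assuming
transformers>=4.36, of how the per-layer kv_cache_list bookkeeping could
eventually move to DynamicCache; the tensor shapes and the num_layers
value are illustrative placeholders, not values taken from airllm:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()  # one Cache object replaces the per-layer k/v lists

    num_layers = 2  # placeholder; airllm processes one decoder layer at a time
    for layer_idx in range(num_layers):
        # stand-ins for the (batch, num_heads, seq_len, head_dim) key/value
        # states a decoder layer returns when called with use_cache=True
        key_states = torch.randn(1, 8, 16, 64)
        value_states = torch.randn(1, 8, 16, 64)
        # Cache.update stores the states for this layer and returns the full
        # cached keys/values, replacing the manual kv_cache_list appends
        k_cache, v_cache = cache.update(key_states, value_states, layer_idx)

    print(cache.get_seq_length())  # 16: cached positions seen so far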