fix new_seq
```diff
@@ -29,6 +29,18 @@ try:
 except ImportError:
     bitsandbytes_installed = False
 
+
+
+try:
+    from transformers.cache_utils import Cache, DynamicCache
+
+    cache_utils_installed = True
+    print('>>>> cache_utils installed')
+except ImportError:
+    cache_utils_installed = False
+
+
+
 total_disk_loading_time = None
 total_gpu_loading_time = None
 total_compression_overhead_time = None
```
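The probe above uses a failed import as a version check: `transformers.cache_utils` (and with it `Cache`/`DynamicCache`) only exists from transformers 4.36 onward. A minimal sketch of the same pattern, assuming only that transformers is importable; `make_kv_cache` is a hypothetical helper, not part of airllm:

```python
# Feature-detect the new Cache API instead of parsing version strings.
try:
    from transformers.cache_utils import DynamicCache
    cache_utils_installed = True
except ImportError:
    cache_utils_installed = False

def make_kv_cache():
    # Return the new Cache object when available; None tells the caller to
    # fall back to the legacy tuple-of-tensors past_key_values format.
    return DynamicCache() if cache_utils_installed else None

if __name__ == "__main__":
    print(f"cache_utils available: {cache_utils_installed}, cache: {make_kv_cache()!r}")
```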
```diff
@@ -111,14 +123,24 @@ class AirLLMLlama2(GenerationMixin):
     def init_model(self):
 
         # Load meta model (no memory used)
-        with init_empty_weights():
-            self.model = AutoModelForCausalLM.from_config(self.config)
-            self.model.eval()
-            try:
+        try:
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config)
+                self.model.eval()
                 self.model = BetterTransformer.transform(self.model) # enable flash attention
-            except ValueError as ve:
-                print(f"new version of transfomer, no need to use BetterTransformer...")
-        self.model.tie_weights()
+            self.model.tie_weights()
+        except ValueError as ve:
+            del self.model
+            clean_memory()
+
+            print(f"new version of transfomer, no need to use BetterTransformer, setting attn impl to sdpa...")
+            self.config.attn_implementation = "sdpa"
+
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa")
+                self.model.eval()
+            self.model.tie_weights()
+            print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
 
         self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
                                                                                           self.model.lm_head]
```
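The reshaped `init_model` builds the model on the meta device (no weights allocated), tries optimum's BetterTransformer for flash attention, and, when a newer transformers rejects that with a `ValueError`, rebuilds the empty model with `attn_implementation="sdpa"`. A hedged sketch of that flow, assuming transformers >= 4.36 and accelerate are installed and optimum may or may not be; the tiny `LlamaConfig` here stands in for the real checkpoint config:

```python
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM, LlamaConfig

# Toy config so the sketch runs offline; airllm would use the checkpoint's config.
config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=1000)

try:
    from optimum.bettertransformer import BetterTransformer
    with init_empty_weights():                      # weights stay on the meta device
        model = AutoModelForCausalLM.from_config(config)
        model.eval()
        model = BetterTransformer.transform(model)  # flash attention via optimum
    model.tie_weights()
except (ImportError, ValueError):
    # Newer transformers ships SDPA natively and BetterTransformer refuses to run;
    # rebuild the empty model asking for the sdpa attention implementation instead.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")
        model.eval()
    model.tie_weights()

print(type(model.model.layers[0].self_attn))
```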
```diff
@@ -211,6 +233,10 @@ class AirLLMLlama2(GenerationMixin):
 
         global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
 
+        if cache_utils_installed:
+            # we don't support kv cache for new version yet
+            use_cache = False
+
         if self.profiling_mode:
             total_disk_loading_time = []
             total_gpu_loading_time = []
```
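The guard above force-disables `use_cache` whenever the new Cache API is present, because the tuple-style `kv_cache_list` bookkeeping later in the file does not understand `DynamicCache` objects. A small sketch of the same decision, assuming transformers is installed; `resolve_use_cache` is an illustrative name, not an airllm function:

```python
import importlib.util

def resolve_use_cache(requested: bool) -> bool:
    # Detect the 4.36+ Cache API; if it is there, skip the legacy kv cache path.
    cache_utils_installed = importlib.util.find_spec("transformers.cache_utils") is not None
    return False if cache_utils_installed else requested

print(resolve_use_cache(True))  # False on transformers >= 4.36, True on older releases
```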
```diff
@@ -318,17 +344,13 @@ class AirLLMLlama2(GenerationMixin):
                     new_seq = layer(seq,
                                     attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])[0]
                 else:
-                    layer_out = layer(seq,
-                                      use_cache=True,
-                                      attention_mask=attention_mask[:, :, -len_seq:,
-                                                                    -len_seq:])
-                    # TODO: adopt Cache mechanism in 4.36
-                    if layer_out[1] is not None:
+                    layer_out = layer(seq, use_cache=True,
+                                      attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
-
-                        #print(f"layer:{type(layer)} layer_out:{layer_out}")
-                        new_seq, (k_cache, v_cache) = layer_out
-                        kv_cache_list[i][0].append(k_cache)
-                        kv_cache_list[i][1].append(v_cache)
+                    # TODO: adopt Cache mechanism in 4.36
+                    new_seq, (k_cache, v_cache) = layer_out
+                    kv_cache_list[i][0].append(k_cache)
+                    kv_cache_list[i][1].append(v_cache)
+
                 # print(f"k_cache size: {k_cache.shape}")
                 # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
 
```
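This last hunk is the `new_seq` fix itself: previously `new_seq` was only assigned inside `if layer_out[1] is not None:`, so when a layer handed back `None` for its kv entry the variable was never set and the next iteration failed. Since the earlier hunk already turns `use_cache` off on new transformers, this branch only runs when the layer really returns a legacy `(k, v)` tuple, so the guard can simply be dropped. A toy reproduction of the fixed flow, with `fake_layer` standing in for a real decoder layer:

```python
def fake_layer(seq, use_cache=True):
    # A legacy-style layer output: (hidden_states, (k_cache, v_cache))
    return seq + 1, (seq, seq)

def run_layer(seq, kv_store):
    layer_out = fake_layer(seq, use_cache=True)
    # Fixed flow: unpack unconditionally instead of hiding the assignment
    # behind `if layer_out[1] is not None:`.
    new_seq, (k_cache, v_cache) = layer_out
    kv_store.append((k_cache, v_cache))
    return new_seq

kv_store = []
print(run_layer(0, kv_store), kv_store)  # 1 [(0, 0)]
```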