mirror of https://github.com/0xSojalSec/airllm.git
fix new_seq
@@ -29,6 +29,18 @@ try:
 except ImportError:
     bitsandbytes_installed = False
 
+try:
+    from transformers.cache_utils import Cache, DynamicCache
+
+    cache_utils_installed = True
+    print('>>>> cache_utils installed')
+except ImportError:
+    cache_utils_installed = False
+
+
+
+
+
 total_disk_loading_time = None
 total_gpu_loading_time = None
 total_compression_overhead_time = None
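
The new guard follows the same probe-and-flag pattern as the existing bitsandbytes check above it: attempt the import once at module load and record the result, so later code can branch on a boolean instead of re-importing. A minimal standalone sketch of the pattern, assuming transformers >= 4.36 (where cache_utils first appears):

# Probe for the new Cache API without making it a hard dependency.
try:
    from transformers.cache_utils import Cache, DynamicCache  # noqa: F401
    cache_utils_installed = True
except ImportError:
    # Older transformers: fall back to the legacy tuple-based kv cache.
    cache_utils_installed = False
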
@@ -111,14 +123,24 @@ class AirLLMLlama2(GenerationMixin):
     def init_model(self):
 
         # Load meta model (no memory used)
-        with init_empty_weights():
-            self.model = AutoModelForCausalLM.from_config(self.config)
-            self.model.eval()
-            try:
-                self.model = BetterTransformer.transform(self.model) # enable flash attention
-            except ValueError as ve:
-                print(f"new version of transfomer, no need to use BetterTransformer...")
-            self.model.tie_weights()
+        try:
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config)
+                self.model.eval()
+                self.model = BetterTransformer.transform(self.model) # enable flash attention
+                self.model.tie_weights()
+        except ValueError as ve:
+            del self.model
+            clean_memory()
+
+            print(f"new version of transfomer, no need to use BetterTransformer, setting attn impl to sdpa...")
+            self.config.attn_implementation = "sdpa"
+
+            with init_empty_weights():
+                self.model = AutoModelForCausalLM.from_config(self.config, attn_implementation="sdpa")
+                self.model.eval()
+                self.model.tie_weights()
+            print(f"attn imp: {type(self.model.model.layers[3].self_attn)}")
 
         self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
                                                                                          self.model.lm_head]
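
The restructured init_model builds the model on the meta device (so no weight memory is allocated), tries optimum's BetterTransformer for flash attention, and catches the ValueError that newer transformers raises when BetterTransformer is no longer needed, rebuilding with the built-in SDPA attention instead. A cut-down sketch of that fallback, outside the class; the model id is a placeholder and the helpers come from accelerate and optimum as in the diff's imports:

from accelerate import init_empty_weights
from optimum.bettertransformer import BetterTransformer
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id

try:
    with init_empty_weights():                      # meta device: no real weight memory
        model = AutoModelForCausalLM.from_config(config)
        model.eval()
        model = BetterTransformer.transform(model)  # flash attention on older transformers
        model.tie_weights()
except ValueError:
    # Newer transformers ships these kernels natively and BetterTransformer
    # refuses to transform; rebuild with built-in SDPA attention instead.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")
        model.eval()
        model.tie_weights()

Note the del self.model / clean_memory() in the diff before the rebuild: the first, failed meta model is dropped explicitly so nothing lingers before the second construction.
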
@@ -211,6 +233,10 @@ class AirLLMLlama2(GenerationMixin):
 
         global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
 
+        if cache_utils_installed:
+            # we don't support kv cache for new version yet
+            use_cache = False
+
         if self.profiling_mode:
             total_disk_loading_time = []
             total_gpu_loading_time = []
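
The guard exists because layers built against the new API hand back a Cache object rather than the per-layer (k, v) tuples the code further down unpacks, so the kv cache is simply disabled until that path is adopted (see the TODO in the next hunk). A small illustration of the new object, assuming transformers >= 4.36:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim)
v = torch.randn(1, 8, 4, 64)
cache.update(k, v, layer_idx=0)  # layers grow the cache in place
print(cache.get_seq_length(0))   # 4 -- a shared object, not a (k, v) tuple per layer
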
@@ -318,17 +344,13 @@ class AirLLMLlama2(GenerationMixin):
                         new_seq = layer(seq,
                                         attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])[0]
                     else:
-                        layer_out = layer(seq,
-                                          use_cache=True,
-                                          attention_mask=attention_mask[:, :, -len_seq:,
-                                                                        -len_seq:])
-                        # TODO: adopt Cache mechanism in 4.36
-                        if layer_out[1] is not None:
+                        layer_out = layer(seq, use_cache=True,
+                                          attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
 
-                        #print(f"layer:{type(layer)} layer_out:{layer_out}")
+                        # TODO: adopt Cache mechanism in 4.36
                         new_seq, (k_cache, v_cache) = layer_out
                         kv_cache_list[i][0].append(k_cache)
                         kv_cache_list[i][1].append(v_cache)
 
                         # print(f"k_cache size: {k_cache.shape}")
                         # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
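
This is the fix the commit message refers to: the `if layer_out[1] is not None:` guard left new_seq unassigned whenever the guard failed, so the unpacking is now done unconditionally and the layer call is collapsed onto two lines. A stripped-down sketch of the legacy kv-cache bookkeeping the hunk keeps, assuming layers, seq and attention_mask are already loaded on the device:

len_seq = seq.shape[1]
kv_cache_list = [([], []) for _ in layers]  # one (keys, values) list pair per layer

for i, layer in enumerate(layers):
    layer_out = layer(seq, use_cache=True,
                      attention_mask=attention_mask[:, :, -len_seq:, -len_seq:])
    # Legacy tuple API: (hidden_states, (k_cache, v_cache)).
    seq, (k_cache, v_cache) = layer_out
    kv_cache_list[i][0].append(k_cache)
    kv_cache_list[i][1].append(v_cache)
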