diff --git a/air_llm/README.md b/air_llm/README.md
index 96afcb7..b8b7ecd 100644
--- a/air_llm/README.md
+++ b/air_llm/README.md
@@ -7,9 +7,9 @@ AirLLM优化inference内存,4GB单卡GPU可以运行70B大语言模型推理
## Updates
-[2023/12/03] added support of **ChatGLM**, **QWen**!
+[2023/12/03] added support of **ChatGLM**, **QWen**, **Baichuan**, **Mistral**, **InternLM**!
-支持ChatGLM, QWEN!
+支持ChatGLM, QWEN, Baichuan, Mistral, InternLM!
[2023/12/02] added support for safetensors. Now support all top 10 models in open llm leaderboard.
@@ -148,6 +148,7 @@ When initialize the model, we support the following configurations:
| 8 | garage-bAInd/Platypus2-70B-instruct | ✅ | AirLLMLlama2 |
| 9 | jondurbin/airoboros-l2-70b-2.2.1 | ✅ | AirLLMLlama2 |
| 10 | chargoddard/Yi-34B-Llama | ✅ | AirLLMLlama2 |
+| ? | mistralai/Mistral-7B-Instruct-v0.1 | ✅ | AirLLMMistral |
#### [opencompass leaderboard](https://opencompass.org.cn/leaderboard-llm) top models
@@ -167,13 +168,14 @@ When initialize the model, we support the following configurations:
| 7 | OrionStarAI/OrionStar-Yi-34B-Chat | ✅ | AirLLMLlama2 |
| 8 | Qwen/Qwen-14B-Chat | ✅ | AirLLMQWen |
| 9 | Duxiaoman-DI/XuanYuan-70B | ✅ | AirLLMLlama2 |
-| 10 | internlm/internlm-20b | ⏰(adding, [to accelerate😀](https://bmc.link/lyogavinQ)) | |
-| 26 | baichuan-inc/Baichuan2-13B-Chat | ⏰(adding, [to accelerate😀](https://bmc.link/lyogavinQ)) | |
+| 10 | internlm/internlm-20b | ✅ | AirLLMInternLM |
+| 26 | baichuan-inc/Baichuan2-13B-Chat | ✅ | AirLLMBaichuan |
-#### example of other models (ChatGLM, QWen, etc):
+#### example of other models (ChatGLM, QWen, Baichuan, Mistral, etc):
+
* ChatGLM:
```python
@@ -215,6 +217,30 @@ generation_output = model.generate(
model.tokenizer.decode(generation_output.sequences[0])
```
+
+* Baichuan, InternLM, Mistral, etc:
+
+```python
+from airllm import AirLLMBaichuan # AirLLMInternLM, AirLLMMistral
+MAX_LENGTH = 128
+model = AirLLMBaichuan("baichuan-inc/Baichuan2-7B-Base")
+#model = AirLLMInternLM("internlm/internlm-20b")
+#model = AirLLMMistral("mistralai/Mistral-7B-Instruct-v0.1")
+input_text = ['What is the capital of China?',]
+input_tokens = model.tokenizer(input_text,
+ return_tensors="pt",
+ return_attention_mask=False,
+ truncation=True,
+ max_length=MAX_LENGTH)
+generation_output = model.generate(
+ input_tokens['input_ids'].cuda(),
+ max_new_tokens=5,
+ use_cache=True,
+ return_dict_in_generate=True)
+model.tokenizer.decode(generation_output.sequences[0])
+```
+
+
diff --git a/air_llm/airllm/__init__.py b/air_llm/airllm/__init__.py
index 60ab080..251ca52 100644
--- a/air_llm/airllm/__init__.py
+++ b/air_llm/airllm/__init__.py
@@ -1,5 +1,8 @@
from .airllm import AirLLMLlama2
from .airllm_chatglm import AirLLMChatGLM
from .airllm_qwen import AirLLMQWen
+from .airllm_baichuan import AirLLMBaichuan
+from .airllm_internlm import AirLLMInternLM
+from .airllm_mistral import AirLLMMistral
from .utils import split_and_save_layers
-from .utils import NotEnoughSpaceException
\ No newline at end of file
+from .utils import NotEnoughSpaceException
diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py
index 953f1be..cae4b37 100644
--- a/air_llm/airllm/airllm.py
+++ b/air_llm/airllm/airllm.py
@@ -278,7 +278,8 @@ class AirLLMLlama2(GenerationMixin):
output_attentions=output_attentions,
past_key_value=kv_cache,
position_ids=pos,
- attention_mask=attn)
+ attention_mask=attn
+ )
new_seq = layer_outputs[0]
if output_attentions:
@@ -322,7 +323,7 @@ class AirLLMLlama2(GenerationMixin):
for i in range(len(kv_cache_list)):
# print(f"{i} - {kv_cache_list[i][0].shape}")
kv_cache_list[i] = (torch.cat(kv_cache_list[i][0], 0), torch.cat(kv_cache_list[i][1], 0))
- print(f"returning kvcache size: {kv_cache_list[0][0].shape}")
+ #print(f"returning kvcache size: {kv_cache_list[0][0].shape}")
if output_attentions:
all_self_attns = all_self_attns[0:-2]
diff --git a/air_llm/airllm/airllm_baichuan.py b/air_llm/airllm/airllm_baichuan.py
new file mode 100644
index 0000000..7636776
--- /dev/null
+++ b/air_llm/airllm/airllm_baichuan.py
@@ -0,0 +1,399 @@
+import gc
+import json
+import os
+from typing import List, Optional, Tuple, Union
+import ctypes
+import shutil
+from tqdm import tqdm
+from pathlib import Path
+from glob import glob
+import time
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig
+
+from .tokenization_baichuan import BaichuanTokenizer
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors.torch import load_file, save_file
+from optimum.bettertransformer import BetterTransformer
+import huggingface_hub
+
+from .utils import save_quant_state_to_dict, NotEnoughSpaceException, clean_memory, uncompress_layer_state_dict, load_layer, \
+ check_space, compress_layer_state_dict, split_and_save_layers, find_or_create_local_splitted_path
+
+try:
+ import bitsandbytes as bnb
+
+ bitsandbytes_installed = True
+ print('>>>> bitsandbytes installed')
+except ImportError:
+ bitsandbytes_installed = False
+
+total_disk_loading_time = None
+total_gpu_loading_time = None
+total_compression_overhead_time = None
+
+
+
+class AirLLMBaichuan(GenerationMixin):
+ def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512,
+ layer_shards_saving_path=None, profiling_mode=False, compression=None):
+ """
+ Sharded version of LlamaForCausalLM : the model is split into layer shards to reduce GPU memory usage.
+ During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
+ To avoid loading the layers multiple times, we could save all the intermediate activations in RAM.
+
+ Parameters
+ ----------
+ model_local_path_or_repo_id : str or Path
+ path to the local model checkpoint or huggingface repo id
+ device : str, optional
+ device, by default "cuda:0"
+ dtype : torch.dtype, optional
+ dtype, by default torch.float16
+ max_seq_len : int, optional
+ max seq length, by default 512
+ layer_shards_saving_path : str, optional
+ optional path to save layered shards model file, by default just save to the local cache of model, subdir named splitted_model will be saved
+ profiling_mode : bool, optional
+ if to profile the model loading time, default to False
+ compression: str, optional
+ setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeds up inference time 4x or 2x with a tiny accuracy loss.
+ """
+
+
+ self.profiling_mode = profiling_mode
+
+ if compression is not None:
+ if not bitsandbytes_installed:
+ raise ImportError('WARNING: bitsandbytes not found. Compression needs bitsandbytes. To use compression, please install bitsandbytes: `pip install bitsandbytes`')
+
+
+ self.compression = compression
+
+ # Save parameters
+
+ self.layer_names_dict = {'embed': 'model.embed_tokens',
+ 'layer_prefix': 'model.layers',
+ 'norm': 'model.norm',
+ 'lm_head': 'lm_head',}
+ self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id,
+ layer_shards_saving_path,
+ compression=compression,
+ layer_names=self.layer_names_dict)
+ self.running_device = device
+ self.device = torch.device(self.running_device)
+ self.running_dtype = dtype
+ self.dtype = self.running_dtype
+
+ # Create model
+ self.config = AutoConfig.from_pretrained(self.model_local_path, trust_remote_code=True)
+ self.generation_config = GenerationConfig()#GenerationConfig.from_pretrained(self.model_local_path)
+ #print(f"using generation_config: {self.generation_config}")
+
+ # use this hack util the bug is fixed: https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/discussions/2
+ self.tokenizer = BaichuanTokenizer.from_pretrained(self.model_local_path, use_fast=False, trust_remote_code=True)
+ #self.tokenizer.pad_token = self.tokenizer.eos_token
+ #self.tokenizer.padding_side = "right"
+ self.init_model()
+ self.layer_names = [self.layer_names_dict['embed']] + [f'{self.layer_names_dict["layer_prefix"]}.{i}' for i in range(len(self.model.model.layers))] + \
+ [self.layer_names_dict['norm'], self.layer_names_dict['lm_head']]
+
+ self.max_seq_len = max_seq_len
+
+ self.main_input_name = "input_ids"
+
+ def init_model(self):
+
+ # Load meta model (no memory used)
+ with init_empty_weights():
+ self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
+ self.model.eval()
+ #self.model = BetterTransformer.transform(self.model) # enable flash attention
+ self.model.tie_weights()
+
+ self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
+ self.model.lm_head]
+
+ # Move buffers to device (not that much GPU memory used)
+ for buffer_name, buffer in self.model.named_buffers():
+ set_module_tensor_to_device(self.model, buffer_name, self.running_device, value=buffer,
+ dtype=self.running_dtype)
+
+ if 'rotary_pos_emb' in self.layer_names_dict:
+ # for glm keep rotary_pos_emb in gpu
+ self.load_rotary_pos_emb_to_device()
+
+ def load_rotary_pos_emb_to_device(self):
+ state_dict = load_layer(self.checkpoint_path, self.layer_names_dict['rotary_pos_emb'])
+ self.move_layer_to_device(state_dict)
+
+ def load_layer_to_cpu(self, layer_name, profiling=False):
+
+ t = time.process_time()
+ load_layer_output = load_layer(self.checkpoint_path, layer_name, profiling)
+ elapsed_time = time.process_time() - t
+
+ if profiling:
+ state_dict, compression_time = load_layer_output
+ disk_loading_time = elapsed_time - compression_time
+ return state_dict, disk_loading_time, compression_time
+ else:
+ state_dict = load_layer_output
+
+ return state_dict
+
+ def move_layer_to_device(self, state_dict):
+ for param_name, param in state_dict.items():
+ #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
+ set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
+ dtype=self.running_dtype)
+
+ # make GenerationMixin happy
+ def can_generate(self):
+ return True
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ #print(f"input_ids shape: {input_ids.shape}")
+
+ global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
+
+ if self.profiling_mode:
+ total_disk_loading_time = []
+ total_gpu_loading_time = []
+ total_compression_overhead_time = []
+ forward_start = time.process_time()
+
+ # Reboot the model to make sure buffers are loaded and memory is clean
+ del self.model
+ clean_memory()
+ self.init_model()
+
+ batch = [input_ids_unit.to(self.running_device).unsqueeze(0) for input_ids_unit in input_ids]
+ n_seq = len(batch[0])
+ #print(f"batch[0] shape:{batch[0].shape}")
+ #batch_eos = [(input_ids_unit != self.tokenizer.pad_token_id).sum(0) - 1 for input_ids_unit in input_ids]
+
+ # Create attention mask for the largest input, and position ids to use KV cache
+ attention_mask = torch.ones(self.max_seq_len, self.max_seq_len)
+ attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0
+ attention_mask = attention_mask.to(self.running_device)
+ position_ids = torch.arange(self.max_seq_len, dtype=torch.long, device=self.running_device)[None, :]
+
+ kv_cache_list = [] if use_cache else None
+ if use_cache:
+ for x in self.layers:
+ kv_cache_list.append(([], []))
+ all_hidden_states = [[] for _ in range(len(self.layers))] if output_hidden_states else None
+ all_self_attns = None
+
+ with torch.inference_mode():
+
+ for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device,
+ total=len(self.layers)):
+ #print(f"layer:{i} {layer_name}")
+
+ load_layer_to_cpu_output = self.load_layer_to_cpu(layer_name, self.profiling_mode)
+ # profile
+ if self.profiling_mode:
+ state_dict, disk_loading_time, compression_time = load_layer_to_cpu_output
+ total_disk_loading_time.append(disk_loading_time)
+ total_compression_overhead_time.append(compression_time)
+ else:
+ state_dict = load_layer_to_cpu_output
+
+ t = time.process_time()
+ self.move_layer_to_device(state_dict)
+ elapsed_time = time.process_time() - t
+ # profile
+ if self.profiling_mode:
+ total_gpu_loading_time.append(elapsed_time)
+
+ # Run layer
+
+ for j, seq in enumerate(batch):
+ #print(f"{j}th in batch shape: {seq.shape}")
+
+ if layer_name == self.layer_names_dict['embed']:
+ batch[j] = layer(seq)
+ elif layer_name == self.layer_names_dict['norm']:
+ #batch[j] = layer(seq[torch.arange(n_seq), batch_eos[j]][:, None])
+ batch[j] = layer(seq)
+
+ if output_hidden_states:
+ all_hidden_states[i].append(batch[j])
+ elif layer_name == self.layer_names_dict['lm_head']:
+ batch[j] = layer(seq).float()
+ else:
+
+ if output_hidden_states:
+ all_hidden_states[i].append(seq)
+
+ if past_key_values is not None:
+ #print(f"len past_key_values: {len(past_key_values)}, past_key_values[0][0] shape:{past_key_values[0][0].shape}")
+ # join past kv
+ k_cache, v_cache = past_key_values[i - 1]
+ len_p = past_key_values[0][0].shape[2]
+ len_s = seq.shape[1]
+
+ pos = position_ids[:, len_p:len_p + len_s]
+
+ attn = attention_mask[:, :, -len_s:, -len_p - len_s:]
+ kv_cache = (k_cache,
+ v_cache,
+ )
+
+ layer_outputs = layer(seq,
+ use_cache=True,
+ output_attentions=output_attentions,
+ past_key_value=kv_cache,
+ position_ids=pos,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ attention_mask=attn
+ )
+ new_seq = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns[i].append(layer_outputs[1])
+
+ if use_cache:
+ (k_cache, v_cache) = layer_outputs[2 if output_attentions else 1]
+ kv_cache_list[i][0].append(k_cache)
+ kv_cache_list[i][1].append(v_cache)
+
+
+ else:
+ len_seq = seq.shape[1]
+
+
+ if not use_cache:
+ new_seq = layer(seq,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ attention_mask=attention_mask[:, :, -len_seq:, -len_seq:]
+ )[0]
+ else:
+ new_seq, (k_cache, v_cache) = layer(seq,
+ use_cache=True,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ attention_mask=attention_mask[:, :, -len_seq:,
+ -len_seq:]
+ )
+ kv_cache_list[i][0].append(k_cache)
+ kv_cache_list[i][1].append(v_cache)
+
+ # print(f"k_cache size: {k_cache.shape}")
+ # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
+
+ batch[j] = new_seq
+
+ if output_hidden_states:
+ all_hidden_states += (torch.cat(batch, 0),)
+
+ # Remove previous layer from memory (including buffers)
+ layer.to("meta")
+ clean_memory() # proposed by CPMP
+
+ logits = torch.cat(batch, 0)
+ if use_cache:
+ kv_cache_list = kv_cache_list[1:-2]
+ for i in range(len(kv_cache_list)):
+ # print(f"{i} - {kv_cache_list[i][0].shape}")
+ kv_cache_list[i] = (torch.cat(kv_cache_list[i][0], 0), torch.cat(kv_cache_list[i][1], 0))
+ #print(f"returning kvcache size: {kv_cache_list[0][0].shape}")
+
+ if output_attentions:
+ all_self_attns = all_self_attns[0:-2]
+ for i in range(len(all_self_attns)):
+ all_self_attns[i] = torch.cat(all_self_attns[i], 0)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states[0:-2]
+ for i in range(len(all_hidden_states)):
+ all_hidden_states[i] = torch.cat(all_hidden_states[i], 0)
+
+ if not return_dict:
+ return tuple(v for v in [logits,
+ tuple(kv_cache_list) if kv_cache_list is not None else None,
+ tuple(all_hidden_states) if all_hidden_states is not None else None,
+ tuple(all_self_attns) if all_self_attns is not None else None] if v is not None)
+ if self.profiling_mode:
+
+ forward_elapsed_time = time.process_time() - forward_start
+
+ if self.compression:
+ print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
+ print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
+ print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}")
+ else:
+ # loading is async/lazy, so can't really distinguish them...
+ print(f"total disk+gpu loading time: {sum(total_disk_loading_time) + sum(total_gpu_loading_time):.04f}")
+ print(f"total infer time(including all above plus gpu compute): {forward_elapsed_time:.04f}")
+
+ total_disk_loading_time = []
+ total_gpu_loading_time = []
+ total_compression_overhead_time = []
+
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=logits,
+ past_key_values=tuple(kv_cache_list) if kv_cache_list is not None else None,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states is not None else None,
+ attentions=tuple(all_self_attns) if all_self_attns is not None else None,
+ )
\ No newline at end of file
diff --git a/air_llm/airllm/airllm_internlm.py b/air_llm/airllm/airllm_internlm.py
new file mode 100644
index 0000000..6d0e690
--- /dev/null
+++ b/air_llm/airllm/airllm_internlm.py
@@ -0,0 +1,398 @@
+import gc
+import json
+import os
+from typing import List, Optional, Tuple, Union
+import ctypes
+import shutil
+from tqdm import tqdm
+from pathlib import Path
+from glob import glob
+import time
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors.torch import load_file, save_file
+from optimum.bettertransformer import BetterTransformer
+import huggingface_hub
+
+from .utils import save_quant_state_to_dict, NotEnoughSpaceException, clean_memory, uncompress_layer_state_dict, load_layer, \
+ check_space, compress_layer_state_dict, split_and_save_layers, find_or_create_local_splitted_path
+
+try:
+ import bitsandbytes as bnb
+
+ bitsandbytes_installed = True
+ print('>>>> bitsandbytes installed')
+except ImportError:
+ bitsandbytes_installed = False
+
+total_disk_loading_time = None
+total_gpu_loading_time = None
+total_compression_overhead_time = None
+
+
+
+class AirLLMInternLM(GenerationMixin):
+ def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512,
+ layer_shards_saving_path=None, profiling_mode=False, compression=None):
+ """
+ Sharded version of LlamaForCausalLM : the model is split into layer shards to reduce GPU memory usage.
+ During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
+ To avoid loading the layers multiple times, we could save all the intermediate activations in RAM.
+
+ Parameters
+ ----------
+ model_local_path_or_repo_id : str or Path
+ path to the local model checkpoint or huggingface repo id
+ device : str, optional
+ device, by default "cuda:0"
+ dtype : torch.dtype, optional
+ dtype, by default torch.float16
+ max_seq_len : int, optional
+ max seq length, by default 512
+ layer_shards_saving_path : str, optional
+ optional path to save layered shards model file, by default just save to the local cache of model, subdir named splitted_model will be saved
+ profiling_mode : bool, optional
+ if to profile the model loading time, default to False
+ compression: str, optional
+ setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeds up inference time 4x or 2x with a tiny accuracy loss.
+ """
+
+
+ self.profiling_mode = profiling_mode
+
+ if compression is not None:
+ if not bitsandbytes_installed:
+ raise ImportError('WARNING: bitsandbytes not found. Compression needs bitsandbytes. To use compression, please install bitsandbytes: `pip install bitsandbytes`')
+
+
+ self.compression = compression
+
+ # Save parameters
+
+ self.layer_names_dict = {'embed': 'model.embed_tokens',
+ 'layer_prefix': 'model.layers',
+ 'norm': 'model.norm',
+ 'lm_head': 'lm_head',}
+ self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id,
+ layer_shards_saving_path,
+ compression=compression,
+ layer_names=self.layer_names_dict)
+ self.running_device = device
+ self.device = torch.device(self.running_device)
+ self.running_dtype = dtype
+ self.dtype = self.running_dtype
+
+ # Create model
+ self.config = AutoConfig.from_pretrained(self.model_local_path, trust_remote_code=True)
+ self.generation_config = GenerationConfig()#GenerationConfig.from_pretrained(self.model_local_path)
+ #print(f"using generation_config: {self.generation_config}")
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_local_path, trust_remote_code=True)
+ #self.tokenizer.pad_token = self.tokenizer.eos_token
+ #self.tokenizer.padding_side = "right"
+ self.init_model()
+ self.layer_names = [self.layer_names_dict['embed']] + [f'{self.layer_names_dict["layer_prefix"]}.{i}' for i in range(len(self.model.model.layers))] + \
+ [self.layer_names_dict['norm'], self.layer_names_dict['lm_head']]
+
+ self.max_seq_len = max_seq_len
+
+ self.main_input_name = "input_ids"
+
+ def init_model(self):
+
+ # Load meta model (no memory used)
+ with init_empty_weights():
+ self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
+ self.model.eval()
+ #self.model = BetterTransformer.transform(self.model) # enable flash attention
+ self.model.tie_weights()
+
+ self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
+ self.model.lm_head]
+
+ # Move buffers to device (not that much GPU memory used)
+ for buffer_name, buffer in self.model.named_buffers():
+ set_module_tensor_to_device(self.model, buffer_name, self.running_device, value=buffer,
+ dtype=self.running_dtype)
+
+ if 'rotary_pos_emb' in self.layer_names_dict:
+ # for glm keep rotary_pos_emb in gpu
+ self.load_rotary_pos_emb_to_device()
+
+ def load_rotary_pos_emb_to_device(self):
+ state_dict = load_layer(self.checkpoint_path, self.layer_names_dict['rotary_pos_emb'])
+ self.move_layer_to_device(state_dict)
+
+ def load_layer_to_cpu(self, layer_name, profiling=False):
+
+ t = time.process_time()
+ load_layer_output = load_layer(self.checkpoint_path, layer_name, profiling)
+ elapsed_time = time.process_time() - t
+
+ if profiling:
+ state_dict, compression_time = load_layer_output
+ disk_loading_time = elapsed_time - compression_time
+ return state_dict, disk_loading_time, compression_time
+ else:
+ state_dict = load_layer_output
+
+ return state_dict
+
+ def move_layer_to_device(self, state_dict):
+ for param_name, param in state_dict.items():
+ #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
+ set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
+ dtype=self.running_dtype)
+
+ # make GenerationMixin happy
+ def can_generate(self):
+ return True
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ #print(f"input_ids shape: {input_ids.shape}")
+
+ global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
+
+ if self.profiling_mode:
+ total_disk_loading_time = []
+ total_gpu_loading_time = []
+ total_compression_overhead_time = []
+ forward_start = time.process_time()
+
+ # Reboot the model to make sure buffers are loaded and memory is clean
+ del self.model
+ clean_memory()
+ self.init_model()
+
+ batch = [input_ids_unit.to(self.running_device).unsqueeze(0) for input_ids_unit in input_ids]
+ n_seq = len(batch[0])
+ #print(f"batch[0] shape:{batch[0].shape}")
+ #batch_eos = [(input_ids_unit != self.tokenizer.pad_token_id).sum(0) - 1 for input_ids_unit in input_ids]
+
+ # Create attention mask for the largest input, and position ids to use KV cache
+ attention_mask = torch.ones(self.max_seq_len, self.max_seq_len)
+ attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0
+ attention_mask = attention_mask.to(self.running_device)
+ position_ids = torch.arange(self.max_seq_len, dtype=torch.long, device=self.running_device)[None, :]
+
+ kv_cache_list = [] if use_cache else None
+ if use_cache:
+ for x in self.layers:
+ kv_cache_list.append(([], []))
+ all_hidden_states = [[] for _ in range(len(self.layers))] if output_hidden_states else None
+ all_self_attns = None
+
+ with torch.inference_mode():
+
+ for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device,
+ total=len(self.layers)):
+ #print(f"layer:{i} {layer_name}")
+
+ load_layer_to_cpu_output = self.load_layer_to_cpu(layer_name, self.profiling_mode)
+ # profile
+ if self.profiling_mode:
+ state_dict, disk_loading_time, compression_time = load_layer_to_cpu_output
+ total_disk_loading_time.append(disk_loading_time)
+ total_compression_overhead_time.append(compression_time)
+ else:
+ state_dict = load_layer_to_cpu_output
+
+ t = time.process_time()
+ self.move_layer_to_device(state_dict)
+ elapsed_time = time.process_time() - t
+ # profile
+ if self.profiling_mode:
+ total_gpu_loading_time.append(elapsed_time)
+
+ # Run layer
+
+ for j, seq in enumerate(batch):
+ #print(f"{j}th in batch shape: {seq.shape}")
+
+ if layer_name == self.layer_names_dict['embed']:
+ batch[j] = layer(seq)
+ elif layer_name == self.layer_names_dict['norm']:
+ #batch[j] = layer(seq[torch.arange(n_seq), batch_eos[j]][:, None])
+ batch[j] = layer(seq)
+
+ if output_hidden_states:
+ all_hidden_states[i].append(batch[j])
+ elif layer_name == self.layer_names_dict['lm_head']:
+ batch[j] = layer(seq).float()
+ else:
+
+ if output_hidden_states:
+ all_hidden_states[i].append(seq)
+
+ if past_key_values is not None:
+ #print(f"len past_key_values: {len(past_key_values)}, past_key_values[0][0] shape:{past_key_values[0][0].shape}")
+ # join past kv
+ k_cache, v_cache = past_key_values[i - 1]
+ len_p = past_key_values[0][0].shape[2]
+ len_s = seq.shape[1]
+
+ pos = position_ids[:, len_p:len_p + len_s]
+
+ attn = attention_mask[:, :, -len_s:, -len_p - len_s:]
+ kv_cache = (k_cache,
+ v_cache,
+ )
+
+ layer_outputs = layer(seq,
+ use_cache=True,
+ output_attentions=output_attentions,
+ past_key_value=kv_cache,
+ position_ids=pos,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ attention_mask=attn
+ )
+ new_seq = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns[i].append(layer_outputs[1])
+
+ if use_cache:
+ (k_cache, v_cache) = layer_outputs[2 if output_attentions else 1]
+ kv_cache_list[i][0].append(k_cache)
+ kv_cache_list[i][1].append(v_cache)
+
+
+ else:
+ len_seq = seq.shape[1]
+ pos = position_ids[:, :len_seq]
+
+
+ if not use_cache:
+ new_seq = layer(seq,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ position_ids=pos,
+ attention_mask=attention_mask[:, :, -len_seq:, -len_seq:]
+ )[0]
+ else:
+ new_seq, (k_cache, v_cache) = layer(seq,
+ use_cache=True,
+ position_ids=pos,
+ #rotary_pos_emb_list=rotary_pos_emb_list,
+ attention_mask=attention_mask[:, :, -len_seq:,
+ -len_seq:]
+ )
+ kv_cache_list[i][0].append(k_cache)
+ kv_cache_list[i][1].append(v_cache)
+
+ # print(f"k_cache size: {k_cache.shape}")
+ # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
+
+ batch[j] = new_seq
+
+ if output_hidden_states:
+ all_hidden_states += (torch.cat(batch, 0),)
+
+ # Remove previous layer from memory (including buffers)
+ layer.to("meta")
+ clean_memory() # proposed by CPMP
+
+ logits = torch.cat(batch, 0)
+ if use_cache:
+ kv_cache_list = kv_cache_list[1:-2]
+ for i in range(len(kv_cache_list)):
+ # print(f"{i} - {kv_cache_list[i][0].shape}")
+ kv_cache_list[i] = (torch.cat(kv_cache_list[i][0], 0), torch.cat(kv_cache_list[i][1], 0))
+ #print(f"returning kvcache size: {kv_cache_list[0][0].shape}")
+
+ if output_attentions:
+ all_self_attns = all_self_attns[0:-2]
+ for i in range(len(all_self_attns)):
+ all_self_attns[i] = torch.cat(all_self_attns[i], 0)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states[0:-2]
+ for i in range(len(all_hidden_states)):
+ all_hidden_states[i] = torch.cat(all_hidden_states[i], 0)
+
+ if not return_dict:
+ return tuple(v for v in [logits,
+ tuple(kv_cache_list) if kv_cache_list is not None else None,
+ tuple(all_hidden_states) if all_hidden_states is not None else None,
+ tuple(all_self_attns) if all_self_attns is not None else None] if v is not None)
+ if self.profiling_mode:
+
+ forward_elapsed_time = time.process_time() - forward_start
+
+ if self.compression:
+ print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
+ print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
+ print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}")
+ else:
+ # loading is async/lazy, so can't really distinguish them...
+ print(f"total disk+gpu loading time: {sum(total_disk_loading_time) + sum(total_gpu_loading_time):.04f}")
+ print(f"total infer time(including all above plus gpu compute): {forward_elapsed_time:.04f}")
+
+ total_disk_loading_time = []
+ total_gpu_loading_time = []
+ total_compression_overhead_time = []
+
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=logits,
+ past_key_values=tuple(kv_cache_list) if kv_cache_list is not None else None,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states is not None else None,
+ attentions=tuple(all_self_attns) if all_self_attns is not None else None,
+ )
\ No newline at end of file
diff --git a/air_llm/airllm/airllm_mistral.py b/air_llm/airllm/airllm_mistral.py
new file mode 100644
index 0000000..c59c13a
--- /dev/null
+++ b/air_llm/airllm/airllm_mistral.py
@@ -0,0 +1,395 @@
+import gc
+import json
+import os
+from typing import List, Optional, Tuple, Union
+import ctypes
+import shutil
+from tqdm import tqdm
+from pathlib import Path
+from glob import glob
+import time
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel, GenerationMixin, LlamaForCausalLM, GenerationConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors.torch import load_file, save_file
+from optimum.bettertransformer import BetterTransformer
+import huggingface_hub
+
+from .utils import save_quant_state_to_dict, NotEnoughSpaceException, clean_memory, uncompress_layer_state_dict, load_layer, \
+ check_space, compress_layer_state_dict, split_and_save_layers, find_or_create_local_splitted_path
+
+try:
+ import bitsandbytes as bnb
+
+ bitsandbytes_installed = True
+ print('>>>> bitsandbytes installed')
+except ImportError:
+ bitsandbytes_installed = False
+
+total_disk_loading_time = None  # module-level profiling accumulators; reset at the top of forward() when profiling_mode is on
+total_gpu_loading_time = None
+total_compression_overhead_time = None
+
+
+
+class AirLLMMistral(GenerationMixin):
+ def __init__(self, model_local_path_or_repo_id, device="cuda:0", dtype=torch.float16, max_seq_len=512,
+ layer_shards_saving_path=None, profiling_mode=False, compression=None):
+ """
+ Sharded version of LlamaForCausalLM : the model is splitted into layer shards to reduce GPU memory usage.
+ During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
+ To avoid loading the layers multiple times, we could save all the intermediate activations in RAM.
+
+ Parameters
+ ----------
+ model_local_path_or_repo_id : str or Path
+ path to the local model checkpoint or huggingface repo id
+ device : str, optional
+ device, by default "cuda:0"
+ dtype : torch.dtype, optional
+ dtype, by default torch.float16
+ max_seq_len : int, optional
+ max seq lenght, by default 512
+ layer_shards_saving_path : str, optional
+ optional path to save layered shards model file, by default just save to the local cache of model, subdir named splitted_model will be saved
+ profiling_mode : book, optional
+ if to profile the model loading time, default to False
+ compression: str, optinal
+ setting to '4bit' or '8bit' to enable compression from 16 bits to 4 bits/8 bits which speeed up 4x or 2x inference time with a tiny accuracy loss.
+ """
+
+
+ self.profiling_mode = profiling_mode
+
+ if compression is not None:
+ if not bitsandbytes_installed:
+ raise ImportError('WARNING: bitsandbytes not found. Compression needs bitsandbytes. To use compression, please install bitsandbytes: `pip install bitsandbytes`')
+
+
+ self.compression = compression
+
+ # Save parameters
+
+ self.layer_names_dict = {'embed': 'model.embed_tokens',
+ 'layer_prefix': 'model.layers',
+ 'norm': 'model.norm',
+ 'lm_head': 'lm_head',}
+ self.model_local_path, self.checkpoint_path = find_or_create_local_splitted_path(model_local_path_or_repo_id,
+ layer_shards_saving_path,
+ compression=compression,
+ layer_names=self.layer_names_dict)
+ self.running_device = device
+ self.device = torch.device(self.running_device)
+ self.running_dtype = dtype
+ self.dtype = self.running_dtype
+
+ # Create model
+ self.config = AutoConfig.from_pretrained(self.model_local_path, trust_remote_code=True)
+ self.generation_config = GenerationConfig()#GenerationConfig.from_pretrained(self.model_local_path)
+ #print(f"using generation_config: {self.generation_config}")
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_local_path, trust_remote_code=True)
+ #self.tokenizer.pad_token = self.tokenizer.eos_token
+ #self.tokenizer.padding_side = "right"
+ self.init_model()
+ self.layer_names = [self.layer_names_dict['embed']] + [f'{self.layer_names_dict["layer_prefix"]}.{i}' for i in range(len(self.model.model.layers))] + \
+ [self.layer_names_dict['norm'], self.layer_names_dict['lm_head']]
+
+ self.max_seq_len = max_seq_len
+
+ self.main_input_name = "input_ids"
+
+    def init_model(self):  # build a weightless (meta-device) model skeleton; real weights are streamed in per layer during forward()
+
+        # Load meta model (no memory used)
+        with init_empty_weights():
+            self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
+            self.model.eval()
+            #self.model = BetterTransformer.transform(self.model) # enable flash attention
+            self.model.tie_weights()
+
+        self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm,
+                                                                                        self.model.lm_head]
+
+        # Move buffers to device (not that much GPU memory used)
+        for buffer_name, buffer in self.model.named_buffers():
+            set_module_tensor_to_device(self.model, buffer_name, self.running_device, value=buffer,
+                                        dtype=self.running_dtype)
+
+        if 'rotary_pos_emb' in self.layer_names_dict:  # NOTE(review): always False for Mistral — the key is never set in layer_names_dict above; presumably kept for the ChatGLM variant. TODO confirm
+            # for glm keep rotary_pos_emb in gpu
+            self.load_rotary_pos_emb_to_device()
+
+    def load_rotary_pos_emb_to_device(self):  # pre-load the rotary positional-embedding weights onto the GPU (only used by ChatGLM-style layer dicts)
+        state_dict = load_layer(self.checkpoint_path, self.layer_names_dict['rotary_pos_emb'])  # BUG FIX: was ['layer_names_dict'], a key that never exists (guaranteed KeyError); the guard in init_model checks 'rotary_pos_emb'
+        self.move_layer_to_device(state_dict)
+
+    def load_layer_to_cpu(self, layer_name, profiling=False):  # read one layer's shard from disk into CPU tensors
+
+        t = time.process_time()
+        load_layer_output = load_layer(self.checkpoint_path, layer_name, profiling)
+        elapsed_time = time.process_time() - t
+
+        if profiling:
+            state_dict, compression_time = load_layer_output  # load_layer returns (state_dict, compression_time) when profiling
+            disk_loading_time = elapsed_time - compression_time  # total wall time minus decompression = pure disk time
+            return state_dict, disk_loading_time, compression_time
+        else:
+            state_dict = load_layer_output
+
+        return state_dict
+
+    def move_layer_to_device(self, state_dict):  # copy every tensor of a layer's state dict onto the target device, casting to running_dtype
+        for param_name, param in state_dict.items():
+            #assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)"
+            set_module_tensor_to_device(self.model, param_name, self.running_device, value=param,
+                                        dtype=self.running_dtype)
+
+    # make GenerationMixin happy
+    def can_generate(self):
+        return True  # required by transformers' GenerationMixin to enable .generate()
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]  # seq-len dim of the cached keys — assumes (batch, heads, seq, dim) layout; TODO confirm
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]  # feed only the tokens not yet covered by the KV cache
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)  # nn.Module-style call protocol so GenerationMixin can invoke the model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        # Layer-by-layer forward pass: each layer's weights are streamed disk -> CPU -> GPU just in time, then freed.
+
+        global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
+
+        if self.profiling_mode:
+            total_disk_loading_time = []
+            total_gpu_loading_time = []
+            total_compression_overhead_time = []
+            forward_start = time.process_time()
+
+        # Reboot the model to make sure buffers are loaded and memory is clean
+        del self.model
+        clean_memory()
+        self.init_model()
+
+        batch = [input_ids_unit.to(self.running_device).unsqueeze(0) for input_ids_unit in input_ids]
+        n_seq = len(batch[0])
+        #print(f"batch[0] shape:{batch[0].shape}")
+        #batch_eos = [(input_ids_unit != self.tokenizer.pad_token_id).sum(0) - 1 for input_ids_unit in input_ids]
+
+        # Create attention mask for the largest input, and position ids to use KV cache
+        attention_mask = torch.ones(self.max_seq_len, self.max_seq_len)
+        attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0
+        attention_mask = attention_mask.to(self.running_device)
+        position_ids = torch.arange(self.max_seq_len, dtype=torch.long, device=self.running_device)[None, :]
+
+        kv_cache_list = [] if use_cache else None
+        if use_cache:
+            for x in self.layers:
+                kv_cache_list.append(([], []))
+        all_hidden_states = [[] for _ in self.layers] if output_hidden_states else None  # BUG FIX: was `[] * len(...)`, which is always [] and made all_hidden_states[i] raise IndexError
+        all_self_attns = [[] for _ in self.layers] if output_attentions else None  # BUG FIX: was hard-coded None, so output_attentions crashed on append
+
+        with torch.inference_mode():
+
+            for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.running_device,
+                                               total=len(self.layers)):
+                #print(f"layer:{i} {layer_name}")
+
+                load_layer_to_cpu_output = self.load_layer_to_cpu(layer_name, self.profiling_mode)
+                # profile
+                if self.profiling_mode:
+                    state_dict, disk_loading_time, compression_time = load_layer_to_cpu_output
+                    total_disk_loading_time.append(disk_loading_time)
+                    total_compression_overhead_time.append(compression_time)
+                else:
+                    state_dict = load_layer_to_cpu_output
+
+                t = time.process_time()
+                self.move_layer_to_device(state_dict)
+                elapsed_time = time.process_time() - t
+                # profile
+                if self.profiling_mode:
+                    total_gpu_loading_time.append(elapsed_time)
+
+                # Run layer
+
+                for j, seq in enumerate(batch):
+                    #print(f"{j}th in batch shape: {seq.shape}")
+
+                    if layer_name == self.layer_names_dict['embed']:
+                        batch[j] = layer(seq)
+                    elif layer_name == self.layer_names_dict['norm']:
+                        #batch[j] = layer(seq[torch.arange(n_seq), batch_eos[j]][:, None])
+                        batch[j] = layer(seq)
+
+                        if output_hidden_states:  # BUG FIX: was gated on output_attentions (copy/paste slip) but writes hidden states
+                            all_hidden_states[i].append(batch[j])
+                    elif layer_name == self.layer_names_dict['lm_head']:
+                        batch[j] = layer(seq).float()
+                    else:
+
+                        if output_hidden_states:  # BUG FIX: was output_attentions, and appended `new_seq` before it was defined (NameError)
+                            all_hidden_states[i].append(seq)  # input of decoder layer i == output of layer i-1
+
+                        if past_key_values is not None:
+                            #print(f"len past_key_values: {len(past_key_values)}, past_key_values[0][0] shape:{past_key_values[0][0].shape}")
+                            # join past kv
+                            k_cache, v_cache = past_key_values[i - 1]
+                            len_p = past_key_values[0][0].shape[2]
+                            len_s = seq.shape[1]
+
+                            pos = position_ids[:, len_p:len_p + len_s]
+
+                            attn = attention_mask[:, :, -len_s:, -len_p - len_s:]
+                            kv_cache = (k_cache,
+                                        v_cache,
+                                        )
+
+                            layer_outputs = layer(seq,
+                                                  use_cache=True,
+                                                  output_attentions=output_attentions,
+                                                  past_key_value=kv_cache,
+                                                  position_ids=pos,
+                                                  #rotary_pos_emb_list=rotary_pos_emb_list,
+                                                  attention_mask=attn
+                                                  )
+                            new_seq = layer_outputs[0]
+
+                            if output_attentions:
+                                all_self_attns[i].append(layer_outputs[1])
+
+                            if use_cache:
+                                (k_cache, v_cache) = layer_outputs[1]  # NOTE(review): if output_attentions is also set, the kv cache is layer_outputs[2] in HF decoder layers — TODO confirm
+                                kv_cache_list[i][0].append(k_cache)
+                                kv_cache_list[i][1].append(v_cache)
+
+
+                        else:
+                            len_seq = seq.shape[1]
+
+
+                            if not use_cache:
+                                new_seq = layer(seq,
+                                                #rotary_pos_emb_list=rotary_pos_emb_list,
+                                                attention_mask=attention_mask[:, :, -len_seq:, -len_seq:]
+                                                )[0]
+                            else:
+                                new_seq, (k_cache, v_cache) = layer(seq,
+                                                                    use_cache=True,
+                                                                    #rotary_pos_emb_list=rotary_pos_emb_list,
+                                                                    attention_mask=attention_mask[:, :, -len_seq:,
+                                                                                   -len_seq:]
+                                                                    )
+                                kv_cache_list[i][0].append(k_cache)
+                                kv_cache_list[i][1].append(v_cache)
+
+                            # print(f"k_cache size: {k_cache.shape}")
+                            # print(f"k_cache sizes: {[len(x[1]) for x in kv_cache_list]}")
+
+                        batch[j] = new_seq
+
+                # NOTE(review): hidden states are collected per-layer inside the batch loop above; the old
+                # `all_hidden_states += (torch.cat(batch, 0),)` here double-counted and broke the final torch.cat.
+
+                # Remove previous layer from memory (including buffers)
+                layer.to("meta")
+                clean_memory() # proposed by CPMP
+
+        logits = torch.cat(batch, 0)
+        if use_cache:
+            kv_cache_list = kv_cache_list[1:-2]
+            for i in range(len(kv_cache_list)):
+                # print(f"{i} - {kv_cache_list[i][0].shape}")
+                kv_cache_list[i] = (torch.cat(kv_cache_list[i][0], 0), torch.cat(kv_cache_list[i][1], 0))
+            #print(f"returning kvcache size: {kv_cache_list[0][0].shape}")
+
+        if output_attentions:
+            all_self_attns = all_self_attns[1:-2]  # BUG FIX: was [0:-2]; slot 0 (embed) never receives attentions and would crash torch.cat
+            for i in range(len(all_self_attns)):
+                all_self_attns[i] = torch.cat(all_self_attns[i], 0)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states[1:-1]  # BUG FIX: was [0:-2]; keep decoder-layer inputs + final norm output, drop the never-filled embed/lm_head slots
+            for i in range(len(all_hidden_states)):
+                all_hidden_states[i] = torch.cat(all_hidden_states[i], 0)
+
+        if not return_dict:
+            return tuple(v for v in [logits,
+                                     tuple(kv_cache_list) if kv_cache_list is not None else None,
+                                     tuple(all_hidden_states) if all_hidden_states is not None else None,
+                                     tuple(all_self_attns) if all_self_attns is not None else None] if v is not None)
+        if self.profiling_mode:
+
+            forward_elapsed_time = time.process_time() - forward_start
+
+            if self.compression:
+                print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
+                print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
+                print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}")
+            else:
+                # loading is async/lazy, so can't really distinguish them...
+                print(f"total disk+gpu loading time: {sum(total_disk_loading_time) + sum(total_gpu_loading_time):.04f}")
+            print(f"total infer time(including all above plus gpu compute): {forward_elapsed_time:.04f}")
+
+            total_disk_loading_time = []
+            total_gpu_loading_time = []
+            total_compression_overhead_time = []
+
+
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=tuple(kv_cache_list) if kv_cache_list is not None else None,
+            hidden_states=tuple(all_hidden_states) if all_hidden_states is not None else None,
+            attentions=tuple(all_self_attns) if all_self_attns is not None else None,  # BUG FIX: was gated on all_hidden_states
+        )
\ No newline at end of file
diff --git a/air_llm/airllm/tokenization_baichuan.py b/air_llm/airllm/tokenization_baichuan.py
new file mode 100644
index 0000000..1d347e6
--- /dev/null
+++ b/air_llm/airllm/tokenization_baichuan.py
@@ -0,0 +1,251 @@
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {},
+ "tokenizer_file": {},
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+
+
+class BaichuanTokenizer(PreTrainedTokenizer):
+ """
+ Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",  # BUG FIX: was "" — the angle-bracket token text appears stripped (markup-escaping artifact); upstream Baichuan/Llama default is "<unk>"
+        bos_token="<s>",  # BUG FIX: was "" (upstream default "<s>")
+        eos_token="</s>",  # BUG FIX: was "" (upstream default "</s>")
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text):
+ """Returns a tokenized string."""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for i, token in enumerate(tokens):
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special and i != 0:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string
+
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
+
+ if token_ids_1 is None:
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (
+ bos_token_id
+ + ([0] * len(token_ids_0))
+ + eos_token_id
+ + bos_token_id
+ + ([0] * len(token_ids_1))
+ + eos_token_id
+ )
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+ sequence pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of ids.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+ if token_ids_1 is not None:
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+ return output
diff --git a/air_llm/setup.py b/air_llm/setup.py
index f8f1ab7..6e23503 100644
--- a/air_llm/setup.py
+++ b/air_llm/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="airllm",
- version="2.3.1",
+ version="2.4.0",
author="Gavin Li",
author_email="gavinli@animaai.cloud",
description="AirLLM allows single 4GB GPU card to run 70B large language models without quantization, distillation or pruning.",