initial add auto model

This commit is contained in:
Yu Li
2023-12-19 11:53:15 -06:00
parent 2ac7d815d8
commit 922e1c09c5
2 changed files with 66 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
from transformers import AutoConfig
from .airllm import AirLLMLlama2
from .airllm_mistral import AirLLMMistral
from .airllm_baichuan import AirLLMBaichuan
from .airllm_internlm import AirLLMInternLM
from .airllm_chatglm import AirLLMChatGLM
from .airllm_qwen import AirLLMQWen
class AutoModel:
    """Factory that picks the AirLLM wrapper matching a checkpoint's architecture.

    The class itself is never instantiated; call
    ``AutoModel.from_pretrained(pretrained_model_name_or_path)`` instead.
    """

    # Substrings searched for in the config's architecture names, in priority
    # order: the first marker that matches wins.
    _ARCHITECTURE_MARKERS = (
        "QWen",
        "Baichuan",
        "ChatGLM",
        "InternLM",
        "Mistral",
        "Llama",
    )

    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def _match_architecture(cls, architectures):
        """Return the first known marker contained in any architecture name.

        Args:
            architectures: The ``architectures`` value of a model config; a
                list of class-name strings, a single string, or ``None``.

        Returns:
            One of ``_ARCHITECTURE_MARKERS``, or ``None`` when nothing matches.
        """
        if architectures is None:
            return None
        if isinstance(architectures, str):
            architectures = [architectures]
        names = [name for name in architectures if isinstance(name, str)]

        for marker in cls._ARCHITECTURE_MARKERS:
            # Substring test: entries are full class names such as
            # "QWenLMHeadModel", never the bare marker.
            if any(marker in name for name in names):
                return marker
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Build the AirLLM model wrapper suited to the given checkpoint.

        The checkpoint's config is loaded with ``AutoConfig`` and its
        architecture names decide which wrapper class is used. Unknown
        architectures fall back to ``AirLLMLlama2`` after printing a notice.

        Args:
            pretrained_model_name_or_path: Repo id or local path of the
                checkpoint. It is used to load the config and is forwarded
                as the first positional argument to the wrapper.
                NOTE(review): assumes every AirLLM wrapper takes the model
                path/repo id as its first positional argument - confirm.
            *inputs: Extra positional arguments forwarded to the wrapper.
            **kwargs: Extra keyword arguments forwarded to the wrapper.

        Returns:
            An instance of the selected AirLLM wrapper class.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True
        )
        # The config is an object, not a dict: read the field as an attribute.
        architectures = getattr(config, "architectures", None)

        wrappers = {
            "QWen": AirLLMQWen,
            "Baichuan": AirLLMBaichuan,
            "ChatGLM": AirLLMChatGLM,
            "InternLM": AirLLMInternLM,
            "Mistral": AirLLMMistral,
            "Llama": AirLLMLlama2,
        }

        marker = cls._match_architecture(architectures)
        if marker is None:
            print(f"unknown architecture: {architectures}, try to use Llama2...")
            model_class = AirLLMLlama2
        else:
            model_class = wrappers[marker]

        return model_class(pretrained_model_name_or_path, *inputs, **kwargs)

View File

@@ -0,0 +1,32 @@
import sys
import unittest
import torch
sys.path.insert(0, '../airllm')
from auto_model import AutoModel
class TestAutoModel(unittest.TestCase):
    """Checks that ``AutoModel.from_pretrained`` picks the expected wrapper class."""

    def setUp(self):
        """No per-test fixtures are needed."""

    def tearDown(self):
        """Nothing to clean up."""

    def test_auto_model_should_return_correct_model(self):
        """Each repo id must resolve to the wrapper class named next to it."""
        mapping_dict = {
            'garage-bAInd/Platypus2-7B': 'AirLLMLlama2',
            'Qwen/Qwen-7B': 'AirLLMQWen',
            'internlm/internlm-chat-7b': 'AirLLMInternLM',
            'THUDM/chatglm3-6b-base': 'AirLLMChatGLM',
            'baichuan-inc/Baichuan2-7B-Base': 'AirLLMBaichuan',
            'mistralai/Mistral-7B-Instruct-v0.1': 'AirLLMMistral',
        }

        for repo_id, expected_class_name in mapping_dict.items():
            # subTest reports every failing repo id instead of stopping at
            # the first mismatch.
            with self.subTest(repo_id=repo_id):
                model = AutoModel.from_pretrained(repo_id)
                self.assertEqual(
                    model.__class__.__name__,
                    expected_class_name,
                    f"expecting {expected_class_name}",
                )
                # Drop the reference before the next model is created.
                del model