mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
initial add auto model
This commit is contained in:
34
air_llm/airllm/auto_model.py
Normal file
34
air_llm/airllm/auto_model.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from transformers import AutoConfig
|
||||
|
||||
from .airllm import AirLLMLlama2
|
||||
from .airllm_mistral import AirLLMMistral
|
||||
from .airllm_baichuan import AirLLMBaichuan
|
||||
from .airllm_internlm import AirLLMInternLM
|
||||
from .airllm_chatglm import AirLLMChatGLM
|
||||
from .airllm_qwen import AirLLMQWen
|
||||
|
||||
|
||||
class AutoModel:
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoModel is designed to be instantiated "
|
||||
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
||||
if "QWen" in config['architectures']:
|
||||
return AirLLMQWen(*inputs, **kwargs)
|
||||
elif "Baichuan" in config['architectures']:
|
||||
return AirLLMBaichuan(*inputs, **kwargs)
|
||||
elif "ChatGLM" in config['architectures']:
|
||||
return AirLLMChatGLM(*inputs, **kwargs)
|
||||
elif "InternLM" in config['architectures']:
|
||||
return AirLLMInternLM(*inputs, **kwargs)
|
||||
elif "Mistral" in config['architectures']:
|
||||
return AirLLMMistral(*inputs, **kwargs)
|
||||
elif "Llama" in config['architectures']:
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
else:
|
||||
print(f"unknown artichitecture: {config['architectures']}, try to use Llama2...")
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
32
air_llm/tests/test_automodel.py
Normal file
32
air_llm/tests/test_automodel.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
sys.path.insert(0, '../airllm')
|
||||
|
||||
from auto_model import AutoModel
|
||||
|
||||
|
||||
|
||||
class TestAutoModel(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_auto_model_should_return_correct_model(self):
|
||||
mapping_dict = {
|
||||
'garage-bAInd/Platypus2-7B': 'AirLLMLlama2',
|
||||
'Qwen/Qwen-7B': 'AirLLMQWen',
|
||||
'internlm/internlm-chat-7b': 'AirLLMInternLM',
|
||||
'THUDM/chatglm3-6b-base': 'AirLLMChatGLM',
|
||||
'baichuan-inc/Baichuan2-7B-Base': 'AirLLMBaichuan',
|
||||
'mistralai/Mistral-7B-Instruct-v0.1': 'AirLLMMistral'
|
||||
}
|
||||
|
||||
|
||||
for k,v in mapping_dict.items():
|
||||
model = AutoModel.from_pretrained(k)
|
||||
self.assertEqual(model.__class__.__name__, v, f"expecting {v}")
|
||||
del model
|
||||
|
||||
Reference in New Issue
Block a user