From 922e1c09c562dbb98f88cca64b25391dc05bdf93 Mon Sep 17 00:00:00 2001
From: Yu Li
Date: Tue, 19 Dec 2023 11:53:15 -0600
Subject: [PATCH] initial add auto model

---
 air_llm/airllm/auto_model.py    | 48 +++++++++++++++++++++++++++++++++
 air_llm/tests/test_automodel.py | 33 ++++++++++++++++++++++
 2 files changed, 81 insertions(+)
 create mode 100644 air_llm/airllm/auto_model.py
 create mode 100644 air_llm/tests/test_automodel.py

diff --git a/air_llm/airllm/auto_model.py b/air_llm/airllm/auto_model.py
new file mode 100644
index 0000000..e91744c
--- /dev/null
+++ b/air_llm/airllm/auto_model.py
@@ -0,0 +1,48 @@
+from transformers import AutoConfig
+
+from .airllm import AirLLMLlama2
+from .airllm_mistral import AirLLMMistral
+from .airllm_baichuan import AirLLMBaichuan
+from .airllm_internlm import AirLLMInternLM
+from .airllm_chatglm import AirLLMChatGLM
+from .airllm_qwen import AirLLMQWen
+
+
+class AutoModel:
+    """Select the AirLLM model class that matches a checkpoint's architecture."""
+
+    def __init__(self):
+        raise EnvironmentError(
+            "AutoModel is designed to be instantiated "
+            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+        """Build the AirLLM model whose class matches the checkpoint's architecture.
+
+        The checkpoint path/repo id is forwarded as the first positional
+        argument of the selected class, followed by ``*inputs``/``**kwargs``.
+        Unrecognised architectures fall back to ``AirLLMLlama2``.
+        """
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
+        # `architectures` is an attribute holding a list of class names (it may
+        # be missing or None); the config object is not subscriptable.
+        architectures = getattr(config, "architectures", None) or []
+        architecture = architectures[0] if architectures else ""
+
+        if "QWen" in architecture:
+            return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif "Baichuan" in architecture:
+            return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif "ChatGLM" in architecture:
+            return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif "InternLM" in architecture:
+            return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif "Mistral" in architecture:
+            return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif "Llama" in architecture:
+            return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
+        else:
+            print(f"unknown architecture: {architectures}, try to use Llama2...")
+            return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
diff --git a/air_llm/tests/test_automodel.py b/air_llm/tests/test_automodel.py
new file mode 100644
index 0000000..8224295
--- /dev/null
+++ b/air_llm/tests/test_automodel.py
@@ -0,0 +1,33 @@
+import sys
+import unittest
+
+import torch
+# auto_model uses package-relative imports, so import it through the airllm package.
+sys.path.insert(0, '..')
+
+from airllm.auto_model import AutoModel
+
+
+
+class TestAutoModel(unittest.TestCase):
+    def setUp(self):
+        pass
+    def tearDown(self):
+        pass
+
+    def test_auto_model_should_return_correct_model(self):
+        mapping_dict = {
+            'garage-bAInd/Platypus2-7B': 'AirLLMLlama2',
+            'Qwen/Qwen-7B': 'AirLLMQWen',
+            'internlm/internlm-chat-7b': 'AirLLMInternLM',
+            'THUDM/chatglm3-6b-base': 'AirLLMChatGLM',
+            'baichuan-inc/Baichuan2-7B-Base': 'AirLLMBaichuan',
+            'mistralai/Mistral-7B-Instruct-v0.1': 'AirLLMMistral'
+        }
+
+
+        for k,v in mapping_dict.items():
+            model = AutoModel.from_pretrained(k)
+            self.assertEqual(model.__class__.__name__, v, f"expecting {v}")
+            del model
+