diff --git a/air_llm/airllm/auto_model.py b/air_llm/airllm/auto_model.py index 86ec1cb..6d24680 100644 --- a/air_llm/airllm/auto_model.py +++ b/air_llm/airllm/auto_model.py @@ -1,3 +1,4 @@ +import importlib from transformers import AutoConfig from .airllm import AirLLMLlama2 @@ -15,20 +16,29 @@ class AutoModel: "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method." ) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + def get_module_class(cls, pretrained_model_name_or_path): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) + if "QWen" in config.architectures[0]: - return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm_qwen", "AirLLMQWen" elif "Baichuan" in config.architectures[0]: - return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm_baichuan", "AirLLMBaichuan" elif "ChatGLM" in config.architectures[0]: - return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm_chatglm", "AirLLMChatGLM" elif "InternLM" in config.architectures[0]: - return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm_internlm", "AirLLMInternLM" elif "Mistral" in config.architectures[0]: - return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm_mistral", "AirLLMMistral" elif "Llama" in config.architectures[0]: - return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs) + return ".airllm", "AirLLMLlama2" else: print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...") - return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs) \ No newline at end of file + return ".airllm", "AirLLMLlama2" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + module, cls = AutoModel.get_module_class(pretrained_model_name_or_path) + + module = importlib.import_module(module) + class_ = getattr(module, cls) + return class_(pretrained_model_name_or_path, *inputs, ** kwargs) \ No newline at end of file diff --git a/air_llm/tests/test_automodel.py b/air_llm/tests/test_automodel.py index 1b71148..01c2f91 100644 --- a/air_llm/tests/test_automodel.py +++ b/air_llm/tests/test_automodel.py @@ -25,7 +25,6 @@ class TestAutoModel(unittest.TestCase): for k,v in mapping_dict.items(): - model = AutoModel.from_pretrained(k) - self.assertEqual(model.__class__.__name__, v, f"expecting {v}") - del model + module, cls = AutoModel.get_module_class(k) + self.assertEqual(cls, v, f"expecting {v}")