fix auto model

This commit is contained in:
Yu Li
2023-12-19 14:23:38 -06:00
parent 22b20d07d4
commit 67dbeed719
2 changed files with 20 additions and 11 deletions

View File

@@ -1,3 +1,4 @@
import importlib
from transformers import AutoConfig
from .airllm import AirLLMLlama2
@@ -15,20 +16,29 @@ class AutoModel:
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
def get_module_class(cls, pretrained_model_name_or_path):
    """Select the AirLLM implementation for a checkpoint's architecture.

    The checkpoint configuration is loaded through ``AutoConfig`` (with
    ``trust_remote_code=True``) and the first entry of
    ``config.architectures`` is matched against a list of known
    architecture keywords. Nothing is imported or instantiated here; only
    the names needed to do so later are returned.

    Args:
        pretrained_model_name_or_path: Passed unchanged to
            ``AutoConfig.from_pretrained``.

    Returns:
        A ``(module_name, class_name)`` tuple of strings. ``module_name``
        is a relative module path (it starts with a dot), so it has to be
        resolved against a package when it is imported.
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

    # A config may carry no architecture list at all; treat that the same
    # way as an unrecognised architecture instead of failing on the index.
    architectures = getattr(config, "architectures", None) or [""]
    architecture = architectures[0]

    # Checked in order with a substring test; the first match wins.
    known_architectures = (
        ("QWen", ".airllm_qwen", "AirLLMQWen"),
        ("Baichuan", ".airllm_baichuan", "AirLLMBaichuan"),
        ("ChatGLM", ".airllm_chatglm", "AirLLMChatGLM"),
        ("InternLM", ".airllm_internlm", "AirLLMInternLM"),
        ("Mistral", ".airllm_mistral", "AirLLMMistral"),
        ("Llama", ".airllm", "AirLLMLlama2"),
    )
    for keyword, module_name, class_name in known_architectures:
        if keyword in architecture:
            return module_name, class_name

    # Best-effort fallback: unknown architectures are attempted as Llama2.
    print(f"unknown artichitecture: {architecture}, try to use Llama2...")
    return ".airllm", "AirLLMLlama2"
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    """Instantiate the AirLLM model class that matches a checkpoint.

    The implementation module is imported lazily, so only the module for
    the selected architecture is loaded.

    Args:
        pretrained_model_name_or_path: Forwarded to ``get_module_class``
            and then to the selected model class as its first argument.
        *inputs: Extra positional arguments forwarded to the model class.
        **kwargs: Extra keyword arguments forwarded to the model class.

    Returns:
        An instance of the class named by ``get_module_class``.
    """
    # Distinct local names: rebinding ``cls`` would shadow the classmethod's
    # own first parameter.
    module_name, class_name = cls.get_module_class(pretrained_model_name_or_path)

    # ``get_module_class`` returns relative module names (leading dot), and
    # ``importlib.import_module`` raises TypeError for a relative name unless
    # the anchoring package is supplied.
    module = importlib.import_module(module_name, package=__package__)
    model_class = getattr(module, class_name)
    return model_class(pretrained_model_name_or_path, *inputs, **kwargs)

View File

@@ -25,7 +25,6 @@ class TestAutoModel(unittest.TestCase):
for k,v in mapping_dict.items():
model = AutoModel.from_pretrained(k)
self.assertEqual(model.__class__.__name__, v, f"expecting {v}")
del model
module, cls = AutoModel.get_module_class(k)
self.assertEqual(cls, v, f"expecting {v}")