mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
fix auto model
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from transformers import AutoConfig
|
||||
|
||||
from .airllm import AirLLMLlama2
|
||||
@@ -15,20 +16,29 @@ class AutoModel:
|
||||
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
def get_module_class(cls, pretrained_model_name_or_path):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
||||
|
||||
if "QWen" in config.architectures[0]:
|
||||
return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm_qwen", "AirLLMQWen"
|
||||
elif "Baichuan" in config.architectures[0]:
|
||||
return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm_baichuan", "AirLLMBaichuan"
|
||||
elif "ChatGLM" in config.architectures[0]:
|
||||
return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm_chatglm", "AirLLMChatGLM"
|
||||
elif "InternLM" in config.architectures[0]:
|
||||
return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm_internlm", "AirLLMInternLM"
|
||||
elif "Mistral" in config.architectures[0]:
|
||||
return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm_mistral", "AirLLMMistral"
|
||||
elif "Llama" in config.architectures[0]:
|
||||
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm", "AirLLMLlama2"
|
||||
else:
|
||||
print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
|
||||
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
return ".airllm", "AirLLMLlama2"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
module, cls = AutoModel.get_module_class(pretrained_model_name_or_path)
|
||||
|
||||
module = importlib.import_module(module)
|
||||
class_ = getattr(module, cls)
|
||||
return class_(pretrained_model_name_or_path, *inputs, ** kwargs)
|
||||
@@ -25,7 +25,6 @@ class TestAutoModel(unittest.TestCase):
|
||||
|
||||
|
||||
for k,v in mapping_dict.items():
|
||||
model = AutoModel.from_pretrained(k)
|
||||
self.assertEqual(model.__class__.__name__, v, f"expecting {v}")
|
||||
del model
|
||||
module, cls = AutoModel.get_module_class(k)
|
||||
self.assertEqual(cls, v, f"expecting {v}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user