fix automodel

This commit is contained in:
Yu Li
2023-12-19 14:02:46 -06:00
parent b5cd55ff44
commit 22b20d07d4

View File

@@ -18,17 +18,17 @@ class AutoModel:
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
if "QWen" in config.architectures[0]:
return AirLLMQWen(*inputs, **kwargs)
return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs)
elif "Baichuan" in config.architectures[0]:
return AirLLMBaichuan(*inputs, **kwargs)
return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs)
elif "ChatGLM" in config.architectures[0]:
return AirLLMChatGLM(*inputs, **kwargs)
return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs)
elif "InternLM" in config.architectures[0]:
return AirLLMInternLM(*inputs, **kwargs)
return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs)
elif "Mistral" in config.architectures[0]:
return AirLLMMistral(*inputs, **kwargs)
return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs)
elif "Llama" in config.architectures[0]:
return AirLLMLlama2(*inputs, **kwargs)
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
else:
print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
return AirLLMLlama2(*inputs, **kwargs)
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)