This commit is contained in:
Yu Li
2023-12-19 13:57:57 -06:00
parent 7c13c0e402
commit b5cd55ff44

View File

@@ -17,18 +17,18 @@ class AutoModel:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
if "QWen" in config['architectures']:
if "QWen" in config.architectures[0]:
return AirLLMQWen(*inputs, **kwargs)
elif "Baichuan" in config['architectures']:
elif "Baichuan" in config.architectures[0]:
return AirLLMBaichuan(*inputs, **kwargs)
elif "ChatGLM" in config['architectures']:
elif "ChatGLM" in config.architectures[0]:
return AirLLMChatGLM(*inputs, **kwargs)
elif "InternLM" in config['architectures']:
elif "InternLM" in config.architectures[0]:
return AirLLMInternLM(*inputs, **kwargs)
elif "Mistral" in config['architectures']:
elif "Mistral" in config.architectures[0]:
return AirLLMMistral(*inputs, **kwargs)
elif "Llama" in config['architectures']:
elif "Llama" in config.architectures[0]:
return AirLLMLlama2(*inputs, **kwargs)
else:
print(f"unknown artichitecture: {config['architectures']}, try to use Llama2...")
print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
return AirLLMLlama2(*inputs, **kwargs)