mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 14:24:44 +00:00
fix automodel
This commit is contained in:
@@ -18,17 +18,17 @@ class AutoModel:
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
||||
if "QWen" in config.architectures[0]:
|
||||
return AirLLMQWen(*inputs, **kwargs)
|
||||
return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "Baichuan" in config.architectures[0]:
|
||||
return AirLLMBaichuan(*inputs, **kwargs)
|
||||
return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "ChatGLM" in config.architectures[0]:
|
||||
return AirLLMChatGLM(*inputs, **kwargs)
|
||||
return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "InternLM" in config.architectures[0]:
|
||||
return AirLLMInternLM(*inputs, **kwargs)
|
||||
return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "Mistral" in config.architectures[0]:
|
||||
return AirLLMMistral(*inputs, **kwargs)
|
||||
return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "Llama" in config.architectures[0]:
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
else:
|
||||
print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
Reference in New Issue
Block a user