mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
fix test
This commit is contained in:
@@ -17,18 +17,18 @@ class AutoModel:
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
||||
if "QWen" in config['architectures']:
|
||||
if "QWen" in config.architectures[0]:
|
||||
return AirLLMQWen(*inputs, **kwargs)
|
||||
elif "Baichuan" in config['architectures']:
|
||||
elif "Baichuan" in config.architectures[0]:
|
||||
return AirLLMBaichuan(*inputs, **kwargs)
|
||||
elif "ChatGLM" in config['architectures']:
|
||||
elif "ChatGLM" in config.architectures[0]:
|
||||
return AirLLMChatGLM(*inputs, **kwargs)
|
||||
elif "InternLM" in config['architectures']:
|
||||
elif "InternLM" in config.architectures[0]:
|
||||
return AirLLMInternLM(*inputs, **kwargs)
|
||||
elif "Mistral" in config['architectures']:
|
||||
elif "Mistral" in config.architectures[0]:
|
||||
return AirLLMMistral(*inputs, **kwargs)
|
||||
elif "Llama" in config['architectures']:
|
||||
elif "Llama" in config.architectures[0]:
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
else:
|
||||
print(f"unknown artichitecture: {config['architectures']}, try to use Llama2...")
|
||||
print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
|
||||
return AirLLMLlama2(*inputs, **kwargs)
|
||||
Reference in New Issue
Block a user