diff --git a/air_llm/airllm/auto_model.py b/air_llm/airllm/auto_model.py
index c6c4bc7..86ec1cb 100644
--- a/air_llm/airllm/auto_model.py
+++ b/air_llm/airllm/auto_model.py
@@ -18,17 +18,17 @@ class AutoModel:
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
         if "QWen" in config.architectures[0]:
-            return AirLLMQWen(*inputs, **kwargs)
+            return AirLLMQWen(pretrained_model_name_or_path, *inputs, **kwargs)
         elif "Baichuan" in config.architectures[0]:
-            return AirLLMBaichuan(*inputs, **kwargs)
+            return AirLLMBaichuan(pretrained_model_name_or_path, *inputs, **kwargs)
         elif "ChatGLM" in config.architectures[0]:
-            return AirLLMChatGLM(*inputs, **kwargs)
+            return AirLLMChatGLM(pretrained_model_name_or_path, *inputs, **kwargs)
         elif "InternLM" in config.architectures[0]:
-            return AirLLMInternLM(*inputs, **kwargs)
+            return AirLLMInternLM(pretrained_model_name_or_path, *inputs, **kwargs)
         elif "Mistral" in config.architectures[0]:
-            return AirLLMMistral(*inputs, **kwargs)
+            return AirLLMMistral(pretrained_model_name_or_path, *inputs, **kwargs)
         elif "Llama" in config.architectures[0]:
-            return AirLLMLlama2(*inputs, **kwargs)
+            return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
         else:
             print(f"unknown artichitecture: {config.architectures[0]}, try to use Llama2...")
-            return AirLLMLlama2(*inputs, **kwargs)
\ No newline at end of file
+            return AirLLMLlama2(pretrained_model_name_or_path, *inputs, **kwargs)
\ No newline at end of file