Support hf_token in AutoModel

This commit is contained in:
Yu Li
2023-12-20 11:50:34 -06:00
parent e76b3b9098
commit eeba6b0a93
2 changed files with 9 additions and 4 deletions

View File

@@ -16,8 +16,12 @@ class AutoModel:
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
def get_module_class(cls, pretrained_model_name_or_path):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
def get_module_class(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if 'hf_token' in kwargs:
print(f"using hf_token")
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, token=kwargs['hf_token'])
else:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
if "QWen" in config.architectures[0]:
return "airllm", "AirLLMQWen"
@@ -37,7 +41,8 @@ class AutoModel:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
module, cls = AutoModel.get_module_class(pretrained_model_name_or_path)
module, cls = AutoModel.get_module_class(pretrained_model_name_or_path, *inputs, **kwargs)
module = importlib.import_module(module)
class_ = getattr(module, cls)

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="airllm",
version="2.6.1",
version="2.6.2",
author="Gavin Li",
author_email="gavinli@animaai.cloud",
description="AirLLM allows single 4GB GPU card to run 70B large language models without quantization, distillation or pruning.",