From a451187e5073e5f03a5e9bac4c4187c1a697812e Mon Sep 17 00:00:00 2001 From: Navodplayer1 Date: Sun, 11 Aug 2024 20:01:34 +0530 Subject: [PATCH] Fix: fixed error when running on cpu, fixed setup.py to read README.md as utf-8 encoding and added post install command to upgrade transformers to avoid rope_scaling error --- air_llm/airllm/airllm_base.py | 11 ++++++++--- air_llm/setup.py | 14 +++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/air_llm/airllm/airllm_base.py b/air_llm/airllm/airllm_base.py index b214cb1..d8e2835 100644 --- a/air_llm/airllm/airllm_base.py +++ b/air_llm/airllm/airllm_base.py @@ -153,7 +153,8 @@ class AirLLMBaseModel(GenerationMixin): self.prefetching = False print(f"not support prefetching for compression for now. loading with no prepetching mode.") - if prefetching: + # torch.cuda.Stream() needs CUDA, so create the stream only when the target device is a CUDA device + if prefetching and device.startswith("cuda"): self.stream = torch.cuda.Stream() else: self.stream = None @@ -285,8 +286,12 @@ class AirLLMBaseModel(GenerationMixin): # pin memory: if self.prefetching: t = time.time() - for k in state_dict.keys(): - state_dict[k].pin_memory() + if torch.cuda.is_available(): # Check if CUDA is available + for k in state_dict.keys(): + state_dict[k].pin_memory() + else: + # Without CUDA the pin_memory step is skipped; print a notice so that is visible in the output + print("Prefetching is enabled, but no pin_memory operation is needed for CPU.") elapsed_time = time.time() - t if self.profiling_mode: diff --git a/air_llm/setup.py b/air_llm/setup.py index 697caff..37e3fd9 100644 --- a/air_llm/setup.py +++ b/air_llm/setup.py @@ -1,6 +1,15 @@ import setuptools +from setuptools.command.install import install +import subprocess -with open("README.md", "r") as fh: +# upgrade transformers to latest version to avoid "`rope_scaling` must be a dictionary with two fields" error +class PostInstallCommand(install): + def run(self): + install.run(self) + subprocess.check_call(["pip", "install", 
"--upgrade", "transformers"]) + +# Windows uses a different default encoding (use a consistent encoding) +with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setuptools.setup( @@ -24,6 +33,9 @@ setuptools.setup( 'scipy', #'bitsandbytes' set it to optional to support fallback when not installable ], + cmdclass={ + 'install': PostInstallCommand, + }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",