mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 14:24:44 +00:00
Fix: fixed error when running on cpu, fixed setup.py to read README.md as utf-8 encoding and added post install command to upgrade transformers to avoid rope_scaling error
This commit is contained in:
@@ -153,7 +153,8 @@ class AirLLMBaseModel(GenerationMixin):
|
||||
self.prefetching = False
|
||||
print(f"not support prefetching for compression for now. loading with no prepetching mode.")
|
||||
|
||||
if prefetching:
|
||||
# this operation should run only if gpu is available
|
||||
if prefetching and device.startswith("cuda"):
|
||||
self.stream = torch.cuda.Stream()
|
||||
else:
|
||||
self.stream = None
|
||||
@@ -285,8 +286,12 @@ class AirLLMBaseModel(GenerationMixin):
|
||||
# pin memory:
|
||||
if self.prefetching:
|
||||
t = time.time()
|
||||
for k in state_dict.keys():
|
||||
state_dict[k].pin_memory()
|
||||
if torch.cuda.is_available(): # Check if CUDA is available
|
||||
for k in state_dict.keys():
|
||||
state_dict[k].pin_memory()
|
||||
else:
|
||||
# For CPU, no action is needed, but you could optionally add a log or message
|
||||
print("Prefetching is enabled, but no pin_memory operation is needed for CPU.")
|
||||
|
||||
elapsed_time = time.time() - t
|
||||
if self.profiling_mode:
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
import setuptools
|
||||
from setuptools.command.install import install
|
||||
import subprocess
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
# upgrade transformers to latest version to avoid "`rope_scaling` must be a dictionary with two fields" error
|
||||
class PostInstallCommand(install):
|
||||
def run(self):
|
||||
install.run(self)
|
||||
subprocess.check_call(["pip", "install", "--upgrade", "transformers"])
|
||||
|
||||
# Windows uses a different default encoding (use a consistent encoding)
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setuptools.setup(
|
||||
@@ -24,6 +33,9 @@ setuptools.setup(
|
||||
'scipy',
|
||||
#'bitsandbytes' set it to optional to support fallback when not installable
|
||||
],
|
||||
cmdclass={
|
||||
'install': PostInstallCommand,
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
|
||||
Reference in New Issue
Block a user