mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
make sure it's on cuda
This commit is contained in:
@@ -29,6 +29,7 @@ except ImportError:
|
||||
|
||||
total_disk_loading_time = None
|
||||
total_gpu_loading_time = None
|
||||
total_compression_overhead_time = None
|
||||
|
||||
|
||||
# replacement for bnb quantstat.as_dict(True), until the bug is fixed....
|
||||
@@ -86,7 +87,7 @@ def uncompress_layer_state_dict(layer_state_dict):
|
||||
quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk}
|
||||
quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda")
|
||||
|
||||
dqv = bnb.functional.dequantize_nf4(v, quant_state)
|
||||
dqv = bnb.functional.dequantize_nf4(v.cuda(), quant_state)
|
||||
uncompressed_layer_state_dict[k] = dqv
|
||||
del layer_state_dict
|
||||
elif any(['8bit' in k for k in layer_state_dict.keys()]):
|
||||
@@ -97,7 +98,7 @@ def uncompress_layer_state_dict(layer_state_dict):
|
||||
absmax = layer_state_dict[k + ".8bit.absmax"]
|
||||
code = layer_state_dict[k + ".8bit.code"]
|
||||
|
||||
dqv = bnb.functional.dequantize_blockwise(v, bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
|
||||
dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
|
||||
uncompressed_layer_state_dict[k] = dqv
|
||||
del layer_state_dict
|
||||
|
||||
@@ -105,7 +106,18 @@ def uncompress_layer_state_dict(layer_state_dict):
|
||||
|
||||
def load_layer(local_path, layer_name):
    """Load one layer's safetensors file and return its uncompressed state dict.

    The file ``<layer_name>.safetensors`` under ``local_path`` is read onto the
    CPU, then handed to ``uncompress_layer_state_dict``; ``clean_memory()`` is
    called afterwards.

    When the module-level ``total_compression_overhead_time`` is not ``None``
    (i.e. profiling has initialised it to a list), the time spent in the
    uncompress + clean-up step is appended to it.  The measurement uses
    ``time.process_time()``, so it reports process CPU time rather than
    wall-clock time.

    Args:
        local_path: Directory containing the per-layer ``.safetensors`` files.
        layer_name: File stem of the layer to load (without the extension).

    Returns:
        Whatever ``uncompress_layer_state_dict`` returns for the loaded layer.
    """
    layer_state_dict = load_file(Path(local_path) / (layer_name + ".safetensors"), device="cpu")

    # Evaluate the profiling condition once so the start timestamp is
    # guaranteed to exist whenever it is read below.
    profiling = total_compression_overhead_time is not None
    if profiling:
        t = time.process_time()

    to_return = uncompress_layer_state_dict(layer_state_dict)

    clean_memory()

    if profiling:
        elapsed_time = time.process_time() - t
        total_compression_overhead_time.append(elapsed_time)

    return to_return
|
||||
|
||||
|
||||
|
||||
@@ -115,7 +127,7 @@ def check_space(checkpoint_path, layer_shards_saving_path=None, compression=None
|
||||
total_shard_files_size_bytes += os.path.getsize(model_shard_file)
|
||||
|
||||
if compression == '4bit':
|
||||
total_shard_files_size_bytes = total_shard_files_size_bytes // 4
|
||||
total_shard_files_size_bytes = int(total_shard_files_size_bytes / 0.2813)
|
||||
elif compression == '8bit':
|
||||
total_shard_files_size_bytes = total_shard_files_size_bytes // 2
|
||||
|
||||
@@ -404,11 +416,12 @@ class AirLLMLlama2(GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
global total_disk_loading_time, total_gpu_loading_time
|
||||
global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
|
||||
|
||||
if self.profiling_mode:
|
||||
total_disk_loading_time = []
|
||||
total_gpu_loading_time = []
|
||||
total_compression_overhead_time = []
|
||||
|
||||
# Reboot the model to make sure buffers are loaded and memory is clean
|
||||
del self.model
|
||||
@@ -550,6 +563,7 @@ class AirLLMLlama2(GenerationMixin):
|
||||
if self.profiling_mode:
|
||||
print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
|
||||
print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
|
||||
print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}")
|
||||
|
||||
total_disk_loading_time = []
|
||||
total_gpu_loading_time = []
|
||||
|
||||
Reference in New Issue
Block a user