make sure its on cuda

This commit is contained in:
Yu Li
2023-12-01 18:21:39 -06:00
parent 290cfa1122
commit d1cec72e25

View File

@@ -29,6 +29,7 @@ except ImportError:
total_disk_loading_time = None
total_gpu_loading_time = None
total_compression_overhead_time = None
# replacement for bnb quantstat.as_dict(True), until the bug is fixed....
@@ -86,7 +87,7 @@ def uncompress_layer_state_dict(layer_state_dict):
quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk}
quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda")
dqv = bnb.functional.dequantize_nf4(v, quant_state)
dqv = bnb.functional.dequantize_nf4(v.cuda(), quant_state)
uncompressed_layer_state_dict[k] = dqv
del layer_state_dict
elif any(['8bit' in k for k in layer_state_dict.keys()]):
@@ -97,7 +98,7 @@ def uncompress_layer_state_dict(layer_state_dict):
absmax = layer_state_dict[k + ".8bit.absmax"]
code = layer_state_dict[k + ".8bit.code"]
dqv = bnb.functional.dequantize_blockwise(v, bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
uncompressed_layer_state_dict[k] = dqv
del layer_state_dict
@@ -105,7 +106,18 @@ def uncompress_layer_state_dict(layer_state_dict):
def load_layer(local_path, layer_name):
layer_state_dict = load_file(Path(local_path) / (layer_name + ".safetensors"), device="cpu")
return uncompress_layer_state_dict(layer_state_dict)
if total_compression_overhead_time is not None:
t = time.process_time()
to_return = uncompress_layer_state_dict(layer_state_dict)
clean_memory()
if total_compression_overhead_time is not None:
elapsed_time = time.process_time() - t
total_compression_overhead_time.append(elapsed_time)
return to_return
@@ -115,7 +127,7 @@ def check_space(checkpoint_path, layer_shards_saving_path=None, compression=None
total_shard_files_size_bytes += os.path.getsize(model_shard_file)
if compression == '4bit':
total_shard_files_size_bytes = total_shard_files_size_bytes // 4
total_shard_files_size_bytes = int(total_shard_files_size_bytes / 0.2813)
elif compression == '8bit':
total_shard_files_size_bytes = total_shard_files_size_bytes // 2
@@ -404,11 +416,12 @@ class AirLLMLlama2(GenerationMixin):
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
global total_disk_loading_time, total_gpu_loading_time
global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time
if self.profiling_mode:
total_disk_loading_time = []
total_gpu_loading_time = []
total_compression_overhead_time = []
# Reboot the model to make sure buffers are loaded and memory is clean
del self.model
@@ -550,6 +563,7 @@ class AirLLMLlama2(GenerationMixin):
if self.profiling_mode:
print(f"total disk loading time: {sum(total_disk_loading_time):.04f}")
print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}")
print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}")
total_disk_loading_time = []
total_gpu_loading_time = []