From d1cec72e255ddf284267dd670652f320b841b9f9 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 18:21:39 -0600 Subject: [PATCH] make sure its on cuda --- air_llm/airllm/airllm.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 9f10469..74f7c16 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -29,6 +29,7 @@ except ImportError: total_disk_loading_time = None total_gpu_loading_time = None +total_compression_overhead_time = None # replacement for bnb quantstat.as_dict(True), until the bug is fixed.... @@ -86,7 +87,7 @@ def uncompress_layer_state_dict(layer_state_dict): quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk} quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda") - dqv = bnb.functional.dequantize_nf4(v, quant_state) + dqv = bnb.functional.dequantize_nf4(v.cuda(), quant_state) uncompressed_layer_state_dict[k] = dqv del layer_state_dict elif any(['8bit' in k for k in layer_state_dict.keys()]): @@ -97,7 +98,7 @@ def uncompress_layer_state_dict(layer_state_dict): absmax = layer_state_dict[k + ".8bit.absmax"] code = layer_state_dict[k + ".8bit.code"] - dqv = bnb.functional.dequantize_blockwise(v, bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16)) + dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16)) uncompressed_layer_state_dict[k] = dqv del layer_state_dict @@ -105,7 +106,18 @@ def uncompress_layer_state_dict(layer_state_dict): def load_layer(local_path, layer_name): layer_state_dict = load_file(Path(local_path) / (layer_name + ".safetensors"), device="cpu") - return uncompress_layer_state_dict(layer_state_dict) + + if total_compression_overhead_time is not None: + t = time.process_time() + + to_return = uncompress_layer_state_dict(layer_state_dict) + + clean_memory() + + if total_compression_overhead_time is not None: + elapsed_time = time.process_time() - t + total_compression_overhead_time.append(elapsed_time) + return to_return @@ -115,7 +127,7 @@ def check_space(checkpoint_path, layer_shards_saving_path=None, compression=None total_shard_files_size_bytes += os.path.getsize(model_shard_file) if compression == '4bit': - total_shard_files_size_bytes = total_shard_files_size_bytes // 4 + total_shard_files_size_bytes = int(total_shard_files_size_bytes / 0.2813) elif compression == '8bit': total_shard_files_size_bytes = total_shard_files_size_bytes // 2 @@ -404,11 +416,12 @@ class AirLLMLlama2(GenerationMixin): return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - global total_disk_loading_time, total_gpu_loading_time + global total_disk_loading_time, total_gpu_loading_time, total_compression_overhead_time if self.profiling_mode: total_disk_loading_time = [] total_gpu_loading_time = [] + total_compression_overhead_time = [] # Reboot the model to make sure buffers are loaded and memory is clean del self.model @@ -550,6 +563,7 @@ class AirLLMLlama2(GenerationMixin): if self.profiling_mode: print(f"total disk loading time: {sum(total_disk_loading_time):.04f}") print(f"total gpu loading time: {sum(total_gpu_loading_time):.04f}") + print(f"total compression overhead time: {sum(total_compression_overhead_time):.04f}") total_disk_loading_time = [] total_gpu_loading_time = []