From f3d9067c5cde3c47923538318b8a9ed7d23c8c19 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 17:12:04 -0600 Subject: [PATCH] 2048 block size for 8bit --- air_llm/airllm/airllm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 3413944..4257c8c 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -97,7 +97,7 @@ def uncompress_layer_state_dict(layer_state_dict): absmax = layer_state_dict[k + ".8bit.absmax"] code = layer_state_dict[k + ".8bit.code"] - dqv = bnb.functional.dequantize(v, (absmax, code)) + dqv = bnb.functional.dequantize_blockwise(v, (absmax, code), blocksize=2048) uncompressed_layer_state_dict[k] = dqv del layer_state_dict @@ -132,7 +132,7 @@ def compress_layer_state_dict(layer_state_dict, compression=None): elif compression == '8bit': compressed_layer_state_dict = {} for k, v in layer_state_dict.items(): - v_quant, (absmax, code) = bnb.functional.quantize(v) + v_quant, (absmax, code) = bnb.functional.quantize_blockwise(v, blocksize=2048) compressed_layer_state_dict[k] = v_quant compressed_layer_state_dict[k + ".8bit.absmax"] = absmax compressed_layer_state_dict[k + ".8bit.code"] = code