2048 block size for 8bit

This commit is contained in:
Yu Li
2023-12-01 17:12:04 -06:00
parent 6ad5ae82ff
commit f3d9067c5c

View File

@@ -97,7 +97,7 @@ def uncompress_layer_state_dict(layer_state_dict):
absmax = layer_state_dict[k + ".8bit.absmax"]
code = layer_state_dict[k + ".8bit.code"]
-dqv = bnb.functional.dequantize(v, (absmax, code))
+dqv = bnb.functional.dequantize_blockwise(v, (absmax, code), blocksize=2048)
uncompressed_layer_state_dict[k] = dqv
del layer_state_dict
@@ -132,7 +132,7 @@ def compress_layer_state_dict(layer_state_dict, compression=None):
elif compression == '8bit':
compressed_layer_state_dict = {}
for k, v in layer_state_dict.items():
-v_quant, (absmax, code) = bnb.functional.quantize(v)
+v_quant, (absmax, code) = bnb.functional.quantize_blockwise(v, blocksize=2048)
compressed_layer_state_dict[k] = v_quant
compressed_layer_state_dict[k + ".8bit.absmax"] = absmax
compressed_layer_state_dict[k + ".8bit.code"] = code