This commit is contained in:
Yu Li
2023-12-01 16:39:16 -06:00
parent 4be015811d
commit cccb70b461

View File

@@ -84,7 +84,7 @@ def uncompress_layer_state_dict(layer_state_dict):
for k, v in layer_state_dict.items():
if '4bit' not in k:
quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk}
quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cpu")
quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda")
dqv = bnb.functional.dequantize_nf4(v, quant_state)
uncompressed_layer_state_dict[k] = dqv