make sure its on cuda

This commit is contained in:
Yu Li
2023-12-01 18:43:52 -06:00
parent 9a5d4dc730
commit d3511affe0

View File

@@ -98,7 +98,11 @@ def uncompress_layer_state_dict(layer_state_dict):
absmax = layer_state_dict[k + ".8bit.absmax"]
code = layer_state_dict[k + ".8bit.code"]
dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
dqv = bnb.functional.dequantize_blockwise(v.cuda(),
bnb.functional.QuantState(absmax=absmax.cuda(),
code=code.cuda(),
blocksize=2048,
dtype=torch.float16))
uncompressed_layer_state_dict[k] = dqv
del layer_state_dict