mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 22:33:47 +00:00
make sure it's on cuda
This commit is contained in:
@@ -98,7 +98,11 @@ def uncompress_layer_state_dict(layer_state_dict):
|
||||
absmax = layer_state_dict[k + ".8bit.absmax"]
|
||||
code = layer_state_dict[k + ".8bit.code"]
|
||||
|
||||
dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16))
|
||||
dqv = bnb.functional.dequantize_blockwise(v.cuda(),
|
||||
bnb.functional.QuantState(absmax=absmax.cuda(),
|
||||
code=code.cuda(),
|
||||
blocksize=2048,
|
||||
dtype=torch.float16))
|
||||
uncompressed_layer_state_dict[k] = dqv
|
||||
del layer_state_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user