From d3511affe058dda86a08938345e6f93780805126 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 18:43:52 -0600 Subject: [PATCH] make sure it's on cuda --- air_llm/airllm/airllm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index bd8aa36..5978e94 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -98,7 +98,11 @@ def uncompress_layer_state_dict(layer_state_dict): absmax = layer_state_dict[k + ".8bit.absmax"] code = layer_state_dict[k + ".8bit.code"] - dqv = bnb.functional.dequantize_blockwise(v.cuda(), bnb.functional.QuantState(absmax=absmax, code=code, blocksize=2048, dtype=torch.float16)) + dqv = bnb.functional.dequantize_blockwise(v.cuda(), + bnb.functional.QuantState(absmax=absmax.cuda(), + code=code.cuda(), + blocksize=2048, + dtype=torch.float16)) uncompressed_layer_state_dict[k] = dqv del layer_state_dict