From cccb70b4616697814985bb6069d8d0592329ce50 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Dec 2023 16:39:16 -0600 Subject: [PATCH] fix ut --- air_llm/airllm/airllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 94ebf15..3413944 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -84,7 +84,7 @@ def uncompress_layer_state_dict(layer_state_dict): for k, v in layer_state_dict.items(): if '4bit' not in k: quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk} - quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cpu") + quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda") dqv = bnb.functional.dequantize_nf4(v, quant_state) uncompressed_layer_state_dict[k] = dqv