diff --git a/air_llm/airllm/airllm.py b/air_llm/airllm/airllm.py index 94ebf15..3413944 100644 --- a/air_llm/airllm/airllm.py +++ b/air_llm/airllm/airllm.py @@ -84,7 +84,7 @@ def uncompress_layer_state_dict(layer_state_dict): for k, v in layer_state_dict.items(): if '4bit' not in k: quant_state_dict = {kk[len(k):]: kv for kk, kv in layer_state_dict.items() if kk.startswith(k) and k != kk} - quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cpu") + quant_state = bnb.functional.QuantState.from_dict(qs_dict=quant_state_dict, device="cuda") dqv = bnb.functional.dequantize_nf4(v, quant_state) uncompressed_layer_state_dict[k] = dqv