From a822c3d307c14b44b6e2ae1640376babd7a216de Mon Sep 17 00:00:00 2001 From: Ruheena Suhani Shaik Date: Sun, 18 May 2025 21:16:32 +0300 Subject: [PATCH] Support NF4 checkpoint loading --- bitsandbytes/backends/hpu.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/hpu.py b/bitsandbytes/backends/hpu.py index 2bc367078..a16b294b4 100644 --- a/bitsandbytes/backends/hpu.py +++ b/bitsandbytes/backends/hpu.py @@ -4,7 +4,7 @@ import torch from bitsandbytes.functional import get_4bit_type -from bitsandbytes.utils import QuantState +from bitsandbytes.utils import QuantState, reverse_4bit_compress_format from .base import Backend from .cpu_xpu_common import ( @@ -191,9 +191,13 @@ def dequantize_nf4_impl( HPU dequantization function for NF4 quantized tensors. """ assert_on_hpu([input, absmax]) + + if len(input.shape) == 2: + input = reverse_4bit_compress_format(input.squeeze()) + out_shape = (math.prod(quant_state.shape),) out_dq = torch.ops.hpu.dequantize_nf4( - input, absmax.to(quant_state.dtype), blocksize, out_shape=out_shape, out_dtype=quant_state.dtype + input, absmax.to(quant_state.dtype), blocksize, out_shape=out_shape, out_dtype=quant_state.dtype ) output = out_dq.reshape(quant_state.shape).T return output