diff --git a/bitsandbytes/backends/hpu.py b/bitsandbytes/backends/hpu.py index 2bc367078..a16b294b4 100644 --- a/bitsandbytes/backends/hpu.py +++ b/bitsandbytes/backends/hpu.py @@ -4,7 +4,7 @@ import torch from bitsandbytes.functional import get_4bit_type -from bitsandbytes.utils import QuantState +from bitsandbytes.utils import QuantState, reverse_4bit_compress_format from .base import Backend from .cpu_xpu_common import ( @@ -191,9 +191,13 @@ def dequantize_nf4_impl( HPU dequantization function for NF4 quantized tensors. """ assert_on_hpu([input, absmax]) + + if len(input.shape) == 2: + input = reverse_4bit_compress_format(input.squeeze()) + out_shape = (math.prod(quant_state.shape),) out_dq = torch.ops.hpu.dequantize_nf4( - input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype + input, absmax.to(quant_state.dtype), blocksize, out_shape=out_shape, out_dtype=quant_state.dtype ) output = out_dq.reshape(quant_state.shape).T return output