Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bitsandbytes/backends/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from bitsandbytes.functional import get_4bit_type
from bitsandbytes.utils import QuantState
from bitsandbytes.utils import QuantState, reverse_4bit_compress_format

from .base import Backend
from .cpu_xpu_common import (
Expand Down Expand Up @@ -191,9 +191,13 @@ def dequantize_nf4_impl(
HPU dequantization function for NF4 quantized tensors.
"""
assert_on_hpu([input, absmax])

if len(input.shape) == 2:
input = reverse_4bit_compress_format(input.squeeze())

out_shape = (math.prod(quant_state.shape),)
out_dq = torch.ops.hpu.dequantize_nf4(
input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
input, absmax.to(quant_state.dtype), blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
)
output = out_dq.reshape(quant_state.shape).T
return output
Expand Down
Loading