diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index f850140a1..59f881cc9 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -28,11 +28,18 @@ "npu", # Ascend NPU "xpu", # Intel GPU "cpu", + "hpu", # Intel Gaudi } # Always register the CPU backend. register_backend("cpu", CPUBackend()) +# Register HPU Backend, if available +if hasattr(torch, "hpu") and torch.hpu.is_available(): + from .backends.hpu import HPUBackend + + register_backend("hpu", HPUBackend()) + # Register either CUDA or ROCm backend, if available. # Only one of these backends can be used at a time, since the torch.device semantics are # the same for both torch+rocm and torch+cuda (e.g. device name is "cuda") diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d224cfe1c..b14d2024c 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -570,7 +570,7 @@ def matmul_4bit( return out else: return MatMul4Bit.apply(A, B, out, bias, quant_state) - elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu": + elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type not in ("npu", "hpu"): if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). 
Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 3d99398fc..afe71c080 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -35,6 +35,9 @@ class CPUBackend(Backend): mm_dequant_compute_dtype = torch.bfloat16 mm_dequant_output_dtype = torch.bfloat16 + def device_synchronize(self): + pass + def int8_double_quant( self, A: torch.Tensor, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 87ffc7360..22e2563d9 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -60,7 +60,7 @@ def _ipex_xpu_version_prereq(major, minor): def _maybe_torch_compile(func): # torch.compile requires g++ and pytorch >= 2.0 - if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: + if gxx_available and _torch_version_prereq(2, 0) and ipex_cpu_only: options = {} # fx_graph_cache requires pytorch >= 2.2 if _torch_version_prereq(2, 2): @@ -369,8 +369,9 @@ def quantize_4bit_impl( out_uint8[abs_scaled_A > key] = val out_uint8 += sign.to(torch.uint8) * 8 elif quant_type == "int8": - for i in range(len(INT8_QUANT_TABLE)): - out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i + map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) + diff = torch.abs(scaled_A.unsqueeze(-1) - map) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) if quant_type == "int8": out = out_uint8 diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index f8c27255f..a3a610580 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -97,6 +97,9 @@ class CUDABackend(Backend): + def device_synchronize(self): + torch.cuda.synchronize() + def transform( self, A: torch.Tensor, diff --git a/bitsandbytes/backends/hpu.py b/bitsandbytes/backends/hpu.py new file mode 100644 index 000000000..2bc367078 --- /dev/null +++ b/bitsandbytes/backends/hpu.py @@ -0,0 +1,298 @@ 
+import math +from typing import Literal, Optional, Tuple + +import torch + +from bitsandbytes.functional import get_4bit_type +from bitsandbytes.utils import QuantState + +from .base import Backend +from .cpu_xpu_common import ( + INT8_QUANT_TABLE, + NF4_QUANT_TABLE, + dequant_8bit, +) + +Tensor = torch.Tensor + + +def assert_on_hpu(tensors): + on_hpu = True + for t in tensors: + if t is None: + continue # NULL pointers are fine + on_hpu &= t.device.type == "hpu" + if not on_hpu: + raise TypeError( + "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" + ) + return on_hpu + + +class HPUBackend(Backend): + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError("Not yet implemented for HPU backend") + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) 
-> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["nf4"] = "nf4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if blocksize is None: + blocksize = 64 + assert_on_hpu([A, absmax, out]) + assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage" + return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) + + def quantize_4bit_impl( + self, + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + ) -> Tensor: + if quant_type not in ["nf4", "int8"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for HPU.") + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) + if 
quant_type == "nf4": + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + elif quant_type == "int8": + map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) + diff = torch.abs(scaled_A.unsqueeze(-1) - map) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + if quant_type == "int8": + out = out_uint8 + code = torch.Tensor(INT8_QUANT_TABLE).to(A.device) + else: + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + # To align with HPU dequantize operator + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = self.quantize_4bit_impl(absmax, blocksize=256, quant_type="int8") + del absmax + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + return out, state + + def dequantize_nf4_impl( + self, + input: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_state: QuantState, + ) -> torch.Tensor: + """ + HPU dequantization function for NF4 quantized tensors. 
+ """ + assert_on_hpu([input, absmax]) + out_shape = (math.prod(quant_state.shape),) + out_dq = torch.ops.hpu.dequantize_nf4( + input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype + ) + output = out_dq.reshape(quant_state.shape).T + return output + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["nf4"] = "nf4", + ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 + + assert_on_hpu([A, absmax, out]) + if quant_state.nested: + absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) + return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state) + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor): + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0): + raise NotImplementedError("Not yet implemented for HPU backend") + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError("Not yet implemented for HPU backend") + + def 
optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for HPU backend") + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for HPU backend") diff --git a/bitsandbytes/backends/mps.py b/bitsandbytes/backends/mps.py index 5b7eda0c7..9400699a9 100644 --- a/bitsandbytes/backends/mps.py +++ b/bitsandbytes/backends/mps.py @@ -8,6 +8,9 @@ class MPSBackend(Backend): + def device_synchronize(self): + torch.mps.synchronize() + def double_quant( self, A: torch.Tensor, diff --git a/bitsandbytes/backends/npu.py b/bitsandbytes/backends/npu.py index d22fe04e8..cd3933879 100644 --- a/bitsandbytes/backends/npu.py +++ b/bitsandbytes/backends/npu.py @@ -29,6 +29,9 @@ def assert_on_npu(tensors): class NPUBackend(Backend): + def device_synchronize(self): + torch.npu.synchronize() + def int8_double_quant( self, A: torch.Tensor, diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py index c1c20aa1e..702c3c386 100644 --- a/bitsandbytes/backends/xpu.py +++ b/bitsandbytes/backends/xpu.py @@ -12,11 +12,28 @@ int8_linear_matmul_impl, int8_mm_dequant_impl, quantize_4bit_impl, + _ipex_xpu_version_prereq ) +try: + import intel_extension_for_pytorch as ipex + 
ipex_xpu = ipex if ipex._C._has_xpu() else None +except BaseException: +    ipex_xpu = None Tensor = torch.Tensor +str2optimizer8bit_blockwise = {} +if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): +    str2optimizer8bit_blockwise = { +        "adam": ( +            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, +            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, +            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, +        ), +    } + + def assert_on_xpu(tensors): on_xpu = True for t in tensors: @@ -35,6 +52,9 @@ class XPUBackend(Backend): mm_dequant_compute_dtype = torch.bfloat16 mm_dequant_output_dtype = torch.bfloat16 + def device_synchronize(self): + torch.xpu.synchronize() + def int8_double_quant( self, A: torch.Tensor, @@ -185,7 +205,19 @@ def dequantize_blockwise( blocksize: int = 4096, nested=False, ) -> torch.Tensor: - raise NotImplementedError + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for blockwise dequantization on XPU device.") + + # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if out.dtype == torch.float16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.bfloat16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.float32: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + def quantize_blockwise( self, @@ -220,7 +252,48 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - raise NotImplementedError + optim_func = None + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for 
8bit optimizer backend on XPU device.") + + assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + optim_func( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + g.numel() + ) + def optimizer_update_32bit( self, diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 52e56bf8e..007bdbf8e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -93,6 +93,14 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() +try: + import intel_extension_for_pytorch as ipex + + assert ipex._C._has_cpu() or ipex._C._has_xpu() + is_ipex_available = True +except Exception: + is_ipex_available = False + try: if torch.version.hip: hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) @@ -107,16 +115,20 @@ def get_native_library() -> BNBNativeLibrary: lib = get_native_library() except Exception as e: lib = None - logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) - if torch.cuda.is_available(): - logger.warning( - f""" -{BNB_BACKEND} Setup failed despite {BNB_BACKEND} being available. Please run the following command to get more information: - -python -m bitsandbytes - -Inspect the output of the command and see if you can locate {BNB_BACKEND} libraries. 
You might need to add them -to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes -and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues -""", + if not is_ipex_available: + logger.error( + f"Could not load bitsandbytes native library: {e}. If you use Intel CPU or XPU, please pip install intel_extension_for_pytorch", + exc_info=True, ) + if torch.cuda.is_available(): + logger.warning( + f""" + {BNB_BACKEND} Setup failed despite {BNB_BACKEND} being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate {BNB_BACKEND} libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues + """, + ) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a76aadb73..d1b3dd581 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -859,7 +859,16 @@ def dequantize_blockwise( if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != "cpu": + if A.device.type == "xpu": + backends[A.device.type].dequantize_blockwise( + A=A, + quant_state=quant_state, + absmax=absmax, + code=quant_state.code, + out=out, + blocksize=blocksize, + nested=quant_state.nested,) + elif A.device.type != "cpu": code = quant_state.code.to(A.device) supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] # Some AMD GPUs have warpsize 64 @@ -1067,7 +1076,7 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1077,7 +1086,7 @@ def 
dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1087,8 +1096,8 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type="fp4", + blocksize: Optional[int] = None, + quant_type: Optional[str] = "fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1106,9 +1115,9 @@ def dequantize_4bit( Required if `quant_state` is not provided and ignored otherwise. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4". Raises: ValueError: Raised when the input data type or blocksize is not supported. @@ -1118,9 +1127,9 @@ def dequantize_4bit( """ ensure_backend_is_available(A.device.type) if quant_state is not None: - absmax = absmax or quant_state.absmax - quant_type = quant_type or quant_state.quant_type - blocksize = blocksize or quant_state.blocksize + absmax = quant_state.absmax + quant_type = quant_state.quant_type + blocksize = quant_state.blocksize if blocksize is None: # Some AMD GPUs have warpsize 64 # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 0ea82575a..f28ef651f 100755 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -345,7 +345,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu", "npu", "xpu", "hpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: @@ -447,7 +447,7 @@ def __init__( ) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype - self.compute_type_is_set = False + self.compute_type_is_set = False if compute_dtype is None else True self.ipex_linear_is_set = False self.quant_state = None self.quant_storage = quant_storage @@ -487,6 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) self.weight.quant_state.ipex = False + self.ipex_linear_is_set = False super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias @@ -496,15 +497,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def set_ipex_linear(self, x: torch.Tensor): if ( - (x.device.type in ("cpu", "xpu")) - and not getattr(self.weight.quant_state, "ipex", False) + not getattr(self.weight.quant_state, "ipex", False) and self.weight.data.dtype == torch.uint8 and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and self.weight.quant_state.quant_type == "nf4" - and not self.training - and x.requires_grad == False ): - enable_ipex_fusion(self, x) + if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): + enable_ipex_fusion(self, x) def forward(self, x: torch.Tensor): # Check if ipex fusion can be used @@ -695,32 +694,30 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... 
def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device in ("cuda", "xpu", "cpu"): + if device is not None: if device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) elif device.type == "cpu": if self.data.dtype == torch.int8: self.CB = self.data - return self else: return self.cpu() elif device.type == "xpu": if self.data.dtype == torch.int8: - self.data = self.data.contiguous().xpu(device) + self.data = self.data.contiguous() self.CB = self.data - return self - else: + if self.data.device.type == "cpu": return self.xpu(device) - else: - new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, - has_fp16_weights=self.has_fp16_weights, - ) - new_param.CB = self.CB - new_param.SCB = self.SCB - return new_param + new_param = Int8Params( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) + new_param.CB = self.CB + new_param.SCB = self.SCB + + return new_param def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 03e0e01d7..0a78b4ade 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -10,6 +10,7 @@ import torch import bitsandbytes.functional as F +from bitsandbytes.backends import backends class MockArgs: @@ -289,11 +290,11 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() + backends[p.device.type].device_synchronize() if self.is_paged: # all paged operation are asynchronous, we need # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() + backends[p.device.type].device_synchronize() return loss diff --git 
a/bitsandbytes/utils.py b/bitsandbytes/utils.py index e3748685e..7d56c4ac3 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -240,10 +240,16 @@ def enable_ipex_fusion(linear, x): ) elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5): converted_weight = reverse_4bit_compress_format(linear.weight.data) - new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) new_zeros = None compensation = None + new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) + # ipex 2.7 requires new_scales is a list of tensors + if _ipex_xpu_version_prereq(2, 7): + new_scales = list(new_scales) + # ipex 2.7 can dequant converted_weight directly. + if linear.training or x.requires_grad == False: + new_weight = converted_weight else: raise ValueError( "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5" diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 4f64f6385..17b2d37d5 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -341,10 +341,10 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise -#### Intel CPU +#### Intel CPU / XPU > [!TIP] -> Intel CPU backend only supports building from source; for now, please follow the instructions below. +> Intel CPU / XPU backend only supports building from source; for now, please follow the instructions below. Similar to the CUDA case, you can compile bitsandbytes from source for Linux and Windows systems.