From 194f63eb6fec54d7317af0469ad4a24e465750ba Mon Sep 17 00:00:00 2001
From: Cano Xiao
Date: Sat, 21 Mar 2026 11:27:48 +0000
Subject: [PATCH] scale_integer support

---
 src/chop/nn/quantizers/__init__.py      |  2 ++
 src/chop/nn/quantizers/scale_integer.py | 98 +++++++++++++++++++++++++
 2 files changed, 100 insertions(+)
 create mode 100644 src/chop/nn/quantizers/scale_integer.py

diff --git a/src/chop/nn/quantizers/__init__.py b/src/chop/nn/quantizers/__init__.py
index 8327c83fa..807838038 100644
--- a/src/chop/nn/quantizers/__init__.py
+++ b/src/chop/nn/quantizers/__init__.py
@@ -2,6 +2,7 @@
 from .block_log import block_log_quantizer
 from .block_minifloat import block_minifloat_quantizer
 from .integer import integer_quantizer, integer_floor_quantizer
+from .scale_integer import scale_integer_quantizer
 from .binary import binary_quantizer, residual_sign_quantizer
 from .ternary import ternary_quantizer
 from .log import log_quantizer
@@ -16,6 +17,7 @@
     "block_minifloat": block_minifloat_quantizer,
     "block_fp": block_fp_quantizer,
     "integer": integer_quantizer,
+    "scale_integer": scale_integer_quantizer,
     "binary": binary_quantizer,
     "ternary": ternary_quantizer,
     "mxint_hardware": mxint_hardware,
diff --git a/src/chop/nn/quantizers/scale_integer.py b/src/chop/nn/quantizers/scale_integer.py
new file mode 100644
index 000000000..c190e1ae4
--- /dev/null
+++ b/src/chop/nn/quantizers/scale_integer.py
@@ -0,0 +1,98 @@
+from numpy import ndarray
+from torch import Tensor
+import torch
+
+from .utils import my_clamp, my_round
+
+
+def _scale_integer_quantize(
+    x: Tensor | ndarray,
+    width: int,
+    is_signed: bool = True,
+    quantile: float | None = None,
+):
+    """Linearly (fake-)quantize ``x`` onto a ``width``-bit integer grid.
+
+    A per-row scale is derived from the magnitudes along the last
+    dimension, so each slice along ``dim=-1`` is quantized independently.
+
+    - forward: scale, round and clamp to the integer grid, then dequantize
+    - backward: straight-through estimator (see ``ScaleIntegerQuantize``)
+
+    Args:
+        x: input tensor or numpy array (Python ints pass through unchanged).
+        width: total bit width of the integer grid.
+        is_signed: use a signed two's-complement range when True.
+        quantile: fraction in (0, 1] of the absolute values used as the
+            clipping point; ``None`` or ``1.0`` uses the per-row max.
+    """
+    # Plain ints are already on the integer grid; checked first because
+    # they have no .abs()/.max() methods.
+    if isinstance(x, int):
+        return x
+
+    # numpy arrays lack torch's .abs()/.max(dim=...) API; convert once and
+    # convert back at the end so callers get the type they passed in.
+    is_numpy = isinstance(x, ndarray)
+    if is_numpy:
+        x = torch.from_numpy(x)
+
+    abs_x = x.abs()
+    if quantile is None or quantile >= 1.0:
+        x_max = abs_x.max(dim=-1, keepdim=True).values
+    else:
+        # Clip outliers: use the requested quantile of |x| as the scaling
+        # point instead of the absolute max.
+        x_max = abs_x.quantile(quantile, dim=-1, keepdim=True)
+    # Epsilon guards against division by zero on all-zero rows.
+    x_max = x_max + 1e-9
+
+    if is_signed:
+        int_min = -(2 ** (width - 1))
+        int_max = 2 ** (width - 1) - 1
+        scale = 2 ** (width - 1) / x_max
+    else:
+        int_min = 0
+        int_max = 2**width - 1
+        scale = 2**width / x_max
+
+    x_q = my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
+    return x_q.numpy() if is_numpy else x_q
+
+
+class ScaleIntegerQuantize(torch.autograd.Function):
+    """Scale-based integer fake-quantization with a straight-through gradient."""
+
+    @staticmethod
+    def forward(
+        ctx, x: Tensor, width: int, is_signed: bool = True, quantile: float = 1.0
+    ):
+        return _scale_integer_quantize(
+            x, width=width, is_signed=is_signed, quantile=quantile
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # STE: pass the upstream gradient through unchanged; width,
+        # is_signed and quantile receive no gradient.
+        grad_input = grad_output.clone()
+        return grad_input, None, None, None
+
+
+def scale_integer_quantizer(
+    x: Tensor | ndarray, width: int, is_signed: bool = True, quantile: float = 1.0
+):
+    """Quantize ``x`` onto a ``width``-bit integer grid using a per-row scale.
+
+    - forward: convert IEEE FP32/64 values to the nearest representable
+      scaled integer, then dequantize
+    - backward: straight-through estimator
+
+    Args:
+        x: input tensor or numpy array.
+        width: total bit width of the integer grid.
+        is_signed: use a signed two's-complement range when True.
+        quantile: fraction in (0, 1] of the absolute values used as the
+            clipping point; 1.0 uses the per-row max.
+    """
+    return ScaleIntegerQuantize.apply(x, width, is_signed, quantile)