Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/chop/nn/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .block_log import block_log_quantizer
from .block_minifloat import block_minifloat_quantizer
from .integer import integer_quantizer, integer_floor_quantizer
from .scale_integer import scale_integer_quantizer
from .binary import binary_quantizer, residual_sign_quantizer
from .ternary import ternary_quantizer
from .log import log_quantizer
Expand All @@ -16,6 +17,7 @@
"block_minifloat": block_minifloat_quantizer,
"block_fp": block_fp_quantizer,
"integer": integer_quantizer,
"scale_integer": scale_integer_quantizer,
"binary": binary_quantizer,
"ternary": ternary_quantizer,
"mxint_hardware": mxint_hardware,
Expand Down
85 changes: 85 additions & 0 deletions src/chop/nn/quantizers/scale_integer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from numpy import ndarray
from torch import Tensor
import torch

from .utils import my_clamp, my_round


def _scale_integer_quantize(
    x: Tensor | ndarray, width: int, is_signed: bool = True, quantile: float = None
):
    """
    Linearly quantize `x` onto a `width`-bit integer grid using a per-row
    dynamic scale, then dequantize back ("fake" quantization).

    The scale is derived from the maximum absolute value along the last
    dimension, so each row of `x` gets its own quantization range.

    - forward: round to the integer grid, clamp to the representable range,
      then rescale back to floating point
    - backward: STE (when invoked through `ScaleIntegerQuantize`)

    Args:
        x: input values. NOTE(review): the annotation also admits `ndarray`,
            but `.abs()`, `.mul()` and `.div()` below are torch methods — a
            bare numpy array would raise AttributeError; confirm whether
            numpy support is actually intended.
        width: total bit width of the integer grid.
        is_signed: use the two's-complement range [-2^(w-1), 2^(w-1)-1]
            instead of the unsigned range [0, 2^w - 1].
        quantile: accepted for API compatibility but currently NOT applied —
            scaling always uses the per-row abs-max. TODO(review): either
            implement quantile-based range clipping or drop the parameter.
    """
    # Per-row dynamic range; the epsilon avoids division by zero on
    # all-zero rows.
    x_max = x.abs().max(dim=-1, keepdim=True).values + 1e-9

    if is_signed:
        int_min = -(2 ** (width - 1))
        int_max = 2 ** (width - 1) - 1
        # x_max itself maps to 2^(w-1); the clamp below folds it onto int_max.
        scale = 2 ** (width - 1) / x_max
    else:
        int_min = 0
        int_max = 2**width - 1
        scale = 2**width / x_max

    if isinstance(x, (Tensor, ndarray)):
        return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
    elif isinstance(x, int):
        # Plain Python ints pass through unchanged.
        return x
    else:
        return my_clamp(my_round(x * scale), int_min, int_max) / scale


class ScaleIntegerQuantize(torch.autograd.Function):
    """Autograd wrapper around `_scale_integer_quantize` with an STE backward."""

    @staticmethod
    def forward(
        ctx, x: Tensor, width: int, is_signed: bool = True, quantile: float = 1.0
    ):
        # Fake-quantize in the forward pass; nothing is saved on `ctx`
        # because the backward pass is a plain straight-through estimator.
        return _scale_integer_quantize(
            x, width=width, is_signed=is_signed, quantile=quantile
        )

    @staticmethod
    def backward(ctx, grad_output):
        # STE: the gradient flows through unchanged; the three non-tensor
        # forward arguments (width, is_signed, quantile) receive None.
        return grad_output.clone(), None, None, None


def scale_integer_quantizer(
    x: Tensor | ndarray, width: int, is_signed: bool = True, quantile: float = 1.0
):
    """
    Fake-quantize `x` onto a `width`-bit integer grid with a per-row dynamic
    scale (abs-max along the last dimension), returning floating-point values
    of the same shape.

    - forward: linear quantize + dequantize (see `_scale_integer_quantize`)
    - backward: straight-through estimator — the gradient passes through
      unchanged

    Args:
        x: input values to quantize.
        width: total bit width of the integer grid.
        is_signed: use the signed range [-2^(w-1), 2^(w-1)-1] instead of
            the unsigned range [0, 2^w - 1].
        quantile: accepted for API compatibility; currently not applied by
            the underlying quantizer.
    """
    return ScaleIntegerQuantize.apply(x, width, is_signed, quantile)
Loading