Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 91 additions & 76 deletions qwix/_src/core/conv_general_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from qwix._src.core import conv_general
from qwix._src.core import numerics
from qwix._src.core import qarray
from qwix._src.core import qarray_qt


@dataclasses.dataclass(slots=True, frozen=True, kw_only=True)
Expand All @@ -47,7 +48,7 @@ class ConvGeneralQtConfig:

# Misc.
disable_channelwise_axes: bool = False
bwd_use_original_residuals: bool = False
clip_gradients: bool = False


# Swaps the first two dimension indices of a specification.
Expand Down Expand Up @@ -138,74 +139,10 @@ def _apply_fwd_scale_to_g(


@interception.disable_interceptions
def conv_general_qt_fwd(
lhs: jax.Array,
rhs: jax.Array,
config: ConvGeneralQtConfig,
window_strides: Sequence[int],
padding: str | Sequence[tuple[int, int]],
lhs_dilation: Sequence[int] | None,
rhs_dilation: Sequence[int] | None,
dimension_numbers: jax.lax.ConvDimensionNumbers | None,
feature_group_count: int,
batch_group_count: int,
) -> tuple[jax.Array, tuple[qarray.MaybeQArray, qarray.MaybeQArray]]:
def conv_general_qt_fwd(lhs, rhs, config, *args):
"""Forward pass for conv_general_qt custom VJP."""
dnums = jax.lax.conv_dimension_numbers(
lhs.shape, rhs.shape, dimension_numbers
)

def _quantize_operand(
operand: jax.Array, *, for_lhs: bool
) -> qarray.MaybeQArray:
"""Quantizes a single operand for the forward pass if configured to do so."""
qtype = config.lhs_qtype if for_lhs else config.rhs_qtype
if not (qtype and numerics.should_quantize(operand.dtype)):
return operand

if for_lhs:
calibration_method = config.lhs_calibration_method
collect_quant_stat = config.lhs_collect_quant_stat
else:
calibration_method = config.rhs_calibration_method
collect_quant_stat = config.rhs_collect_quant_stat

how = conv_general.get_how_to_quantize(
dimension_numbers=dnums,
for_lhs=for_lhs,
qtype=qtype,
calibration_method=calibration_method,
)
if config.disable_channelwise_axes:
how = dataclasses.replace(how, channelwise_axes=[])

calibration = qarray.calibrate(operand, how)
if collect_quant_stat:
calibration = collect_quant_stat(calibration)
scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
return qarray.quantize_with_scale_zero_point(
operand, qtype, scale, zero_point
)

residuals = (lhs, rhs)
lhs = _quantize_operand(lhs, for_lhs=True)
rhs = _quantize_operand(rhs, for_lhs=False)
if not config.bwd_use_original_residuals:
residuals = (lhs, rhs)

primal_out = conv_general.conv_general_dilated(
lhs,
rhs,
window_strides,
padding,
lhs_dilation,
rhs_dilation,
dnums,
feature_group_count,
batch_group_count,
)

return primal_out, residuals
del config
return conv_general.conv_general_dilated(lhs, rhs, *args), (lhs, rhs)


def conv_general_qt_bwd(
Expand All @@ -217,9 +154,15 @@ def conv_general_qt_bwd(
dimension_numbers: jax.lax.ConvDimensionNumbers | None,
feature_group_count: int,
batch_group_count: int,
res: tuple[qarray.MaybeQArray, qarray.MaybeQArray],
res: tuple[
jax.Array | qarray_qt.QArrayWithGradient,
jax.Array | qarray_qt.QArrayWithGradient,
],
g: jax.Array,
):
) -> tuple[
jax.Array | qarray_qt.QArrayWithGradient,
jax.Array | qarray_qt.QArrayWithGradient,
]:
"""Backward pass for conv_general_qt custom VJP."""
lhs, rhs = res

Expand Down Expand Up @@ -289,6 +232,10 @@ def conv_general_qt_bwd(
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
)
if isinstance(res[0], qarray_qt.QArrayWithGradient):
dlhs = dataclasses.replace(
res[0], qvalue=None, scale=None, zero_point=None, _grad=dlhs
)

# drhs
drhs_dnums = jax.lax.ConvDimensionNumbers(
Expand Down Expand Up @@ -333,11 +280,45 @@ def conv_general_qt_bwd(
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
)
if isinstance(res[1], qarray_qt.QArrayWithGradient):
drhs = dataclasses.replace(
res[1], qvalue=None, scale=None, zero_point=None, _grad=drhs
)

return dlhs, drhs


@functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def conv_general_qt_fwd_bwd(
    lhs: jax.Array | qarray_qt.QArrayWithGradient,
    rhs: jax.Array | qarray_qt.QArrayWithGradient,
    config: ConvGeneralQtConfig,
    window_strides: Sequence[int],
    padding: str | Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int] | None = None,
    rhs_dilation: Sequence[int] | None = None,
    dimension_numbers: jax.lax.ConvDimensionNumbers | None = None,
    feature_group_count: int = 1,
    batch_group_count: int = 1,
) -> jax.Array:
  """Primal function of the conv_general custom VJP.

  Forwards both operands and the convolution parameters, positionally and
  unchanged, to `conv_general.conv_general_dilated`. Only `lhs` and `rhs` are
  differentiable; every argument from `config` onwards (positions 2-9) is
  declared non-differentiable through `nondiff_argnums`.

  Args:
    lhs: Left-hand operand, a plain array or a `QArrayWithGradient`.
    rhs: Right-hand operand, a plain array or a `QArrayWithGradient`.
    config: Quantization config. Not read by this function; it is part of the
      signature as a non-differentiable argument.
    window_strides: Forwarded to `conv_general_dilated`.
    padding: Forwarded to `conv_general_dilated`.
    lhs_dilation: Forwarded to `conv_general_dilated`.
    rhs_dilation: Forwarded to `conv_general_dilated`.
    dimension_numbers: Forwarded to `conv_general_dilated`.
    feature_group_count: Forwarded to `conv_general_dilated`.
    batch_group_count: Forwarded to `conv_general_dilated`.

  Returns:
    The result of `conv_general.conv_general_dilated`.
  """
  del config  # Unused by the primal computation.
  # Kept in the exact positional order that conv_general_dilated receives.
  conv_params = (
      window_strides,
      padding,
      lhs_dilation,
      rhs_dilation,
      dimension_numbers,
      feature_group_count,
      batch_group_count,
  )
  return conv_general.conv_general_dilated(lhs, rhs, *conv_params)


# Register the forward and backward rules of the conv_general_qt_fwd_bwd
# custom VJP.
conv_general_qt_fwd_bwd.defvjp(conv_general_qt_fwd, conv_general_qt_bwd)


def conv_general_qt(
lhs: jax.Array,
rhs: jax.Array,
Expand All @@ -350,8 +331,46 @@ def conv_general_qt(
feature_group_count: int = 1,
batch_group_count: int = 1,
) -> jax.Array:
"""Quantized conv_general using a simple, hashable config dataclass."""
result, _ = conv_general_qt_fwd(
"""Quantized conv_general_dilated with backpropagation support."""
dnums = jax.lax.conv_dimension_numbers(
lhs.shape, rhs.shape, dimension_numbers
)

def _quantize_operand(
operand: jax.Array, *, for_lhs: bool
) -> qarray.MaybeQArray:
"""Quantizes a single operand for the forward pass if configured to do so."""
qtype = config.lhs_qtype if for_lhs else config.rhs_qtype
if not (qtype and numerics.should_quantize(operand.dtype)):
return operand

if for_lhs:
calibration_method = config.lhs_calibration_method
collect_quant_stat = config.lhs_collect_quant_stat
else:
calibration_method = config.rhs_calibration_method
collect_quant_stat = config.rhs_collect_quant_stat

how = conv_general.get_how_to_quantize(
dimension_numbers=dnums,
for_lhs=for_lhs,
qtype=qtype,
calibration_method=calibration_method,
)
if config.disable_channelwise_axes:
how = dataclasses.replace(how, channelwise_axes=[])

calibration = qarray.calibrate(operand, how)
if collect_quant_stat:
calibration = collect_quant_stat(calibration)
return qarray_qt.quantize_with_calibration(
operand, qtype, calibration, clip_gradient=config.clip_gradients
)

lhs = _quantize_operand(lhs, for_lhs=True)
rhs = _quantize_operand(rhs, for_lhs=False)

return conv_general_qt_fwd_bwd(
lhs,
rhs,
config,
Expand All @@ -363,7 +382,3 @@ def conv_general_qt(
feature_group_count,
batch_group_count,
)
return result


conv_general_qt.defvjp(conv_general_qt_fwd, conv_general_qt_bwd)
Loading