diff --git a/qwix/_src/core/conv_general_qt.py b/qwix/_src/core/conv_general_qt.py
index b2e8147..8eb5e3a 100644
--- a/qwix/_src/core/conv_general_qt.py
+++ b/qwix/_src/core/conv_general_qt.py
@@ -25,6 +25,7 @@
 from qwix._src.core import conv_general
 from qwix._src.core import numerics
 from qwix._src.core import qarray
+from qwix._src.core import qarray_qt


 @dataclasses.dataclass(slots=True, frozen=True, kw_only=True)
@@ -47,7 +48,7 @@ class ConvGeneralQtConfig:

   # Misc.
   disable_channelwise_axes: bool = False
-  bwd_use_original_residuals: bool = False
+  clip_gradients: bool = False


 # Swaps the first two dimension indices of a specification.
@@ -138,74 +139,10 @@ def _apply_fwd_scale_to_g(


 @interception.disable_interceptions
-def conv_general_qt_fwd(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    config: ConvGeneralQtConfig,
-    window_strides: Sequence[int],
-    padding: str | Sequence[tuple[int, int]],
-    lhs_dilation: Sequence[int] | None,
-    rhs_dilation: Sequence[int] | None,
-    dimension_numbers: jax.lax.ConvDimensionNumbers | None,
-    feature_group_count: int,
-    batch_group_count: int,
-) -> tuple[jax.Array, tuple[qarray.MaybeQArray, qarray.MaybeQArray]]:
+def conv_general_qt_fwd(lhs, rhs, config, *args):
   """Forward pass for conv_general_qt custom VJP."""
-  dnums = jax.lax.conv_dimension_numbers(
-      lhs.shape, rhs.shape, dimension_numbers
-  )
-
-  def _quantize_operand(
-      operand: jax.Array, *, for_lhs: bool
-  ) -> qarray.MaybeQArray:
-    """Quantizes a single operand for the forward pass if configured to do so."""
-    qtype = config.lhs_qtype if for_lhs else config.rhs_qtype
-    if not (qtype and numerics.should_quantize(operand.dtype)):
-      return operand
-
-    if for_lhs:
-      calibration_method = config.lhs_calibration_method
-      collect_quant_stat = config.lhs_collect_quant_stat
-    else:
-      calibration_method = config.rhs_calibration_method
-      collect_quant_stat = config.rhs_collect_quant_stat
-
-    how = conv_general.get_how_to_quantize(
-        dimension_numbers=dnums,
-        for_lhs=for_lhs,
-        qtype=qtype,
-        calibration_method=calibration_method,
-    )
-    if config.disable_channelwise_axes:
-      how = dataclasses.replace(how, channelwise_axes=[])
-
-    calibration = qarray.calibrate(operand, how)
-    if collect_quant_stat:
-      calibration = collect_quant_stat(calibration)
-    scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
-    return qarray.quantize_with_scale_zero_point(
-        operand, qtype, scale, zero_point
-    )
-
-  residuals = (lhs, rhs)
-  lhs = _quantize_operand(lhs, for_lhs=True)
-  rhs = _quantize_operand(rhs, for_lhs=False)
-  if not config.bwd_use_original_residuals:
-    residuals = (lhs, rhs)
-
-  primal_out = conv_general.conv_general_dilated(
-      lhs,
-      rhs,
-      window_strides,
-      padding,
-      lhs_dilation,
-      rhs_dilation,
-      dnums,
-      feature_group_count,
-      batch_group_count,
-  )
-
-  return primal_out, residuals
+  del config
+  return conv_general.conv_general_dilated(lhs, rhs, *args), (lhs, rhs)


 def conv_general_qt_bwd(
@@ -217,9 +154,15 @@
     dimension_numbers: jax.lax.ConvDimensionNumbers | None,
     feature_group_count: int,
     batch_group_count: int,
-    res: tuple[qarray.MaybeQArray, qarray.MaybeQArray],
+    res: tuple[
+        jax.Array | qarray_qt.QArrayWithGradient,
+        jax.Array | qarray_qt.QArrayWithGradient,
+    ],
     g: jax.Array,
-):
+) -> tuple[
+    jax.Array | qarray_qt.QArrayWithGradient,
+    jax.Array | qarray_qt.QArrayWithGradient,
+]:
   """Backward pass for conv_general_qt custom VJP."""
   lhs, rhs = res

@@ -289,6 +232,10 @@
       feature_group_count=feature_group_count,
       batch_group_count=batch_group_count,
   )
+  if isinstance(res[0], qarray_qt.QArrayWithGradient):
+    dlhs = dataclasses.replace(
+        res[0], qvalue=None, scale=None, zero_point=None, _grad=dlhs
+    )

   # drhs
   drhs_dnums = jax.lax.ConvDimensionNumbers(
@@ -333,11 +280,45 @@
       feature_group_count=feature_group_count,
       batch_group_count=batch_group_count,
   )
+  if isinstance(res[1], qarray_qt.QArrayWithGradient):
+    drhs = dataclasses.replace(
+        res[1], qvalue=None, scale=None, zero_point=None, _grad=drhs
+    )

   return dlhs, drhs


 @functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
+def conv_general_qt_fwd_bwd(
+    lhs: jax.Array | qarray_qt.QArrayWithGradient,
+    rhs: jax.Array | qarray_qt.QArrayWithGradient,
+    config: ConvGeneralQtConfig,
+    window_strides: Sequence[int],
+    padding: str | Sequence[tuple[int, int]],
+    lhs_dilation: Sequence[int] | None = None,
+    rhs_dilation: Sequence[int] | None = None,
+    dimension_numbers: jax.lax.ConvDimensionNumbers | None = None,
+    feature_group_count: int = 1,
+    batch_group_count: int = 1,
+) -> jax.Array:
+  """conv_general custom VJP."""
+  del config
+  return conv_general.conv_general_dilated(
+      lhs,
+      rhs,
+      window_strides,
+      padding,
+      lhs_dilation,
+      rhs_dilation,
+      dimension_numbers,
+      feature_group_count,
+      batch_group_count,
+  )
+
+
+conv_general_qt_fwd_bwd.defvjp(conv_general_qt_fwd, conv_general_qt_bwd)
+
+
 def conv_general_qt(
     lhs: jax.Array,
     rhs: jax.Array,
@@ -350,8 +331,46 @@
     feature_group_count: int = 1,
     batch_group_count: int = 1,
 ) -> jax.Array:
-  """Quantized conv_general using a simple, hashable config dataclass."""
-  result, _ = conv_general_qt_fwd(
+  """Quantized conv_general_dilated with backpropagation support."""
+  dnums = jax.lax.conv_dimension_numbers(
+      lhs.shape, rhs.shape, dimension_numbers
+  )
+
+  def _quantize_operand(
+      operand: jax.Array, *, for_lhs: bool
+  ) -> qarray.MaybeQArray:
+    """Quantizes a single operand for the forward pass if configured to do so."""
+    qtype = config.lhs_qtype if for_lhs else config.rhs_qtype
+    if not (qtype and numerics.should_quantize(operand.dtype)):
+      return operand
+
+    if for_lhs:
+      calibration_method = config.lhs_calibration_method
+      collect_quant_stat = config.lhs_collect_quant_stat
+    else:
+      calibration_method = config.rhs_calibration_method
+      collect_quant_stat = config.rhs_collect_quant_stat
+
+    how = conv_general.get_how_to_quantize(
+        dimension_numbers=dnums,
+        for_lhs=for_lhs,
+        qtype=qtype,
+        calibration_method=calibration_method,
+    )
+    if config.disable_channelwise_axes:
+      how = dataclasses.replace(how, channelwise_axes=[])
+
+    calibration = qarray.calibrate(operand, how)
+    if collect_quant_stat:
+      calibration = collect_quant_stat(calibration)
+    return qarray_qt.quantize_with_calibration(
+        operand, qtype, calibration, clip_gradient=config.clip_gradients
+    )
+
+  lhs = _quantize_operand(lhs, for_lhs=True)
+  rhs = _quantize_operand(rhs, for_lhs=False)
+
+  return conv_general_qt_fwd_bwd(
     lhs,
     rhs,
     config,
@@ -363,7 +382,3 @@
     feature_group_count,
     batch_group_count,
   )
-  return result
-
-
-conv_general_qt.defvjp(conv_general_qt_fwd, conv_general_qt_bwd)
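After this change, `conv_general_qt` quantizes both operands into `QArrayWithGradient` values before calling the custom-VJP wrapper `conv_general_qt_fwd_bwd`, so the quantized operands double as residuals and gradients flow back through the `_grad` attribute. A minimal usage sketch, not part of the patch; the shapes, dtypes, and config values are illustrative only:

```python
import jax
import jax.numpy as jnp
from qwix._src.core import conv_general_qt

# Quantize both convolution operands to int8 in the forward pass.
config = conv_general_qt.ConvGeneralQtConfig(
    lhs_qtype=jnp.int8, rhs_qtype=jnp.int8
)
lhs = jnp.ones((1, 8, 8, 4))   # NHWC activations
rhs = jnp.ones((3, 3, 4, 16))  # HWIO filters

def loss(lhs, rhs):
  out = conv_general_qt.conv_general_qt(
      lhs, rhs, config,
      window_strides=(1, 1),
      padding='SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
  )
  return jnp.sum(out)

# Gradients are taken w.r.t. the original float arrays; the straight-through
# estimator carried by QArrayWithGradient routes them past the quantization.
dlhs, drhs = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
```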
diff --git a/qwix/_src/core/dot_general_qt.py b/qwix/_src/core/dot_general_qt.py
index 0469e31..7c31fc1 100644
--- a/qwix/_src/core/dot_general_qt.py
+++ b/qwix/_src/core/dot_general_qt.py
@@ -24,6 +24,7 @@
 from qwix._src.core import dot_general
 from qwix._src.core import numerics
 from qwix._src.core import qarray
+from qwix._src.core import qarray_qt


 @dataclasses.dataclass(slots=True, frozen=True, kw_only=True)
@@ -43,19 +44,17 @@ class DotGeneralQtConfig:
   dlhs_grad_qtype: jax.typing.DTypeLike | None = None  # incoming gradient
   dlhs_grad_calibration_method: str = 'absmax'
   dlhs_tile_size: int | float | None = None
+  dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None

   # Backward pass (drhs).
   drhs_grad_qtype: jax.typing.DTypeLike | None = None  # incoming gradient
   drhs_grad_calibration_method: str = 'absmax'
   drhs_tile_size: int | float | None = None
+  drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None

   # Misc.
   disable_channelwise_axes: bool = False
-  bwd_use_original_residuals: bool = False  # what to use as residuals
-
-  # Configs for stochastic rounding.
-  dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
-  drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
+  clip_gradients: bool = False


 def _ranges_like(*xs):
@@ -124,64 +123,23 @@ def _apply_rhs_scale_to_lhs(lhs, rhs_scale, dnums):
 # disable interceptions for dot_general_qt_fwd.
 @interception.disable_interceptions
 def dot_general_qt_fwd(
-    lhs: jax.Array,
-    rhs: jax.Array,
+    lhs: jax.Array | qarray_qt.QArrayWithGradient,
+    rhs: jax.Array | qarray_qt.QArrayWithGradient,
     dimension_numbers: jax.lax.DotDimensionNumbers,
     config: DotGeneralQtConfig,
 ):
   """Forward pass for dot_general_qt custom VJP."""
-  ndims = (lhs.ndim, rhs.ndim)
-
-  def _quantize_operand(operand: jax.Array, is_lhs: bool) -> qarray.MaybeQArray:
-    """Quantizes a single operand for the forward pass if configured to do so."""
-    if is_lhs:
-      qtype = config.lhs_qtype
-      calibration_method = config.lhs_calibration_method
-      collect_quant_stat = config.lhs_collect_quant_stat
-    else:
-      qtype = config.rhs_qtype
-      calibration_method = config.rhs_calibration_method
-      collect_quant_stat = config.rhs_collect_quant_stat
-
-    if not (qtype and numerics.should_quantize(operand.dtype)):
-      return operand
-
-    how = dot_general.get_how_to_quantize(
-        dimension_numbers=dimension_numbers,
-        ndims=ndims,
-        for_lhs=is_lhs,
-        qtype=qtype,
-        tile_size=config.tile_size,
-        calibration_method=calibration_method,
-    )
-    if config.disable_channelwise_axes:
-      how = dataclasses.replace(how, channelwise_axes=[])
-
-    calibration = qarray.calibrate(operand, how)
-    if collect_quant_stat:
-      calibration = collect_quant_stat(calibration)
-    scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
-    return qarray.quantize_with_scale_zero_point(
-        operand, how.qtype, scale, zero_point
-    )
-
-  qlhs = _quantize_operand(lhs, is_lhs=True)
-  qrhs = _quantize_operand(rhs, is_lhs=False)
-
-  primal_out = dot_general.dot_general(qlhs, qrhs, dimension_numbers)
-
-  if config.bwd_use_original_residuals:
-    residuals = (lhs, rhs)
-  else:
-    residuals = (qlhs, qrhs)
-
-  return primal_out, residuals
+  del config
+  return dot_general.dot_general(lhs, rhs, dimension_numbers), (lhs, rhs)


 def dot_general_qt_bwd(
     fwd_dimension_numbers: jax.lax.DotDimensionNumbers,
     config: DotGeneralQtConfig,
-    residuals: tuple[qarray.MaybeQArray, qarray.MaybeQArray],
+    residuals: tuple[
+        jax.Array | qarray_qt.QArrayWithGradient,
+        jax.Array | qarray_qt.QArrayWithGradient,
+    ],
     g: jax.Array,
 ):
   """Backward pass for dot_general_qt custom VJP."""
@@ -189,8 +147,8 @@

   def _compute_gradient_for_operand(
       g: jax.Array, y: qarray.MaybeQArray, *, for_dlhs: bool
-  ):
-    """Compute dot_general for gradient and other_fwd_operand."""
+  ) -> jax.Array | qarray_qt.QArrayWithGradient:
+    """Computes the gradient for one operand from g and the other operand y."""
     bwd_dnums, transpose_axes = _update_dimension_numbers_for_backward(
         fwd_dimension_numbers, (lhs.ndim, rhs.ndim), for_dlhs=for_dlhs
     )
@@ -198,10 +156,14 @@
     if for_dlhs:
       g_qtype = config.dlhs_grad_qtype
       g_tile_size = config.dlhs_tile_size
      g_calibration_method = config.dlhs_grad_calibration_method
+      g_noise_fn = config.dlhs_stochastic_rounding_noise_fn
+      result_type = lhs  # the result gradient must match this type.
     else:
       g_qtype = config.drhs_grad_qtype
       g_tile_size = config.drhs_tile_size
       g_calibration_method = config.drhs_grad_calibration_method
+      g_noise_fn = config.drhs_stochastic_rounding_noise_fn
+      result_type = rhs  # the result gradient must match this type.

     if g_qtype and numerics.should_quantize(g.dtype):
       if isinstance(y, qarray.QArray) and not qarray.get_tiled_axes(y):
@@ -219,23 +181,20 @@
           tile_size=g_tile_size,
           calibration_method=g_calibration_method,
       )
+      g_how = dataclasses.replace(g_how, noise_fn=g_noise_fn)
       if config.disable_channelwise_axes:
         g_how = dataclasses.replace(g_how, channelwise_axes=[])
-      if for_dlhs and config.dlhs_stochastic_rounding_noise_fn:
-        g_how = dataclasses.replace(
-            g_how,
-            noise_fn=config.dlhs_stochastic_rounding_noise_fn,
-        )
-      if not for_dlhs and config.drhs_stochastic_rounding_noise_fn:
-        g_how = dataclasses.replace(
-            g_how,
-            noise_fn=config.drhs_stochastic_rounding_noise_fn,
-        )
       g = qarray.quantize(g, g_how)

     grad_res = dot_general.dot_general(g, y, bwd_dnums)
-    return jax.lax.transpose(grad_res, transpose_axes)
+    grad_res = jax.lax.transpose(grad_res, transpose_axes)
+    if isinstance(result_type, qarray_qt.QArrayWithGradient):
+      return dataclasses.replace(
+          result_type, qvalue=None, scale=None, zero_point=None, _grad=grad_res
+      )
+    else:
+      return grad_res

   dlhs = _compute_gradient_for_operand(g, rhs, for_dlhs=True)
   drhs = _compute_gradient_for_operand(g, lhs, for_dlhs=False)
@@ -244,6 +203,20 @@


 @functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3))
+def dot_general_qt_fwd_bwd(
+    lhs: jax.Array | qarray_qt.QArrayWithGradient,
+    rhs: jax.Array | qarray_qt.QArrayWithGradient,
+    dimension_numbers: jax.lax.DotDimensionNumbers,
+    config: DotGeneralQtConfig,
+) -> jax.Array:
+  """dot_general custom VJP."""
+  del config
+  return dot_general.dot_general(lhs, rhs, dimension_numbers)
+
+
+dot_general_qt_fwd_bwd.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
+
+
 def dot_general_qt(
     lhs: jax.Array,
     rhs: jax.Array,
@@ -251,8 +224,42 @@
     config: DotGeneralQtConfig,
 ) -> jax.Array:
   """Quantized dot_general with backpropagation support."""
-  result, _ = dot_general_qt_fwd(lhs, rhs, dimension_numbers, config)
-  return result
+  if config.lhs_qtype and numerics.should_quantize(lhs.dtype):
+    how = dot_general.get_how_to_quantize(
+        dimension_numbers=dimension_numbers,
+        ndims=(lhs.ndim, rhs.ndim),
+        for_lhs=True,
+        qtype=config.lhs_qtype,
+        tile_size=config.tile_size,
+        calibration_method=config.lhs_calibration_method,
+    )
+    if config.disable_channelwise_axes:
+      how = dataclasses.replace(how, channelwise_axes=[])
+
+    calibration = qarray.calibrate(lhs, how)
+    if config.lhs_collect_quant_stat:
+      calibration = config.lhs_collect_quant_stat(calibration)
+    lhs = qarray_qt.quantize_with_calibration(
+        lhs, how.qtype, calibration, clip_gradient=config.clip_gradients
+    )
+
+  if config.rhs_qtype and numerics.should_quantize(rhs.dtype):
+    how = dot_general.get_how_to_quantize(
+        dimension_numbers=dimension_numbers,
+        ndims=(lhs.ndim, rhs.ndim),
+        for_lhs=False,
+        qtype=config.rhs_qtype,
+        tile_size=config.tile_size,
+        calibration_method=config.rhs_calibration_method,
+    )
+    if config.disable_channelwise_axes:
+      how = dataclasses.replace(how, channelwise_axes=[])
+    calibration = qarray.calibrate(rhs, how)
+    if config.rhs_collect_quant_stat:
+      calibration = config.rhs_collect_quant_stat(calibration)
+    rhs = qarray_qt.quantize_with_calibration(
+        rhs, how.qtype, calibration, clip_gradient=config.clip_gradients
+    )


-dot_general_qt.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
+  return dot_general_qt_fwd_bwd(lhs, rhs, dimension_numbers, config)
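The same pattern applies to `dot_general_qt`: operands are quantized into `QArrayWithGradient` up front, and `dlhs_grad_qtype`/`drhs_grad_qtype` optionally quantize the incoming gradient in the backward pass. An illustrative sketch, not from the patch; the dimension numbers, shapes, and seeds are made up:

```python
import jax
import jax.numpy as jnp
from qwix._src.core import dot_general_qt

config = dot_general_qt.DotGeneralQtConfig(
    lhs_qtype=jnp.int8,        # quantize activations
    rhs_qtype=jnp.int8,        # quantize weights
    dlhs_grad_qtype=jnp.int8,  # quantize incoming gradients for dlhs
    drhs_grad_qtype=jnp.int8,  # quantize incoming gradients for drhs
)
x = jax.random.normal(jax.random.key(0), (16, 32))
w = jax.random.normal(jax.random.key(1), (32, 8))
dnums = (((1,), (0,)), ((), ()))  # plain matmul dimension numbers

out, vjp_fn = jax.vjp(
    lambda x, w: dot_general_qt.dot_general_qt(x, w, dnums, config), x, w
)
dx, dw = vjp_fn(jnp.ones_like(out))  # gradients w.r.t. the float inputs
```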
diff --git a/qwix/_src/core/qarray_qt.py b/qwix/_src/core/qarray_qt.py
new file mode 100644
index 0000000..e8910bc
--- /dev/null
+++ b/qwix/_src/core/qarray_qt.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""QArray with gradient for custom VJP."""
+
+import dataclasses
+from typing import Mapping
+import flax.struct
+import jax
+from qwix._src.core import qarray
+
+
+@flax.struct.dataclass
+class QArrayWithGradient(qarray.QArray):
+  """QArray with gradient.
+
+  This dataclass associates a gradient with a QArray. It does so by defining
+  an extra attribute `_grad` on the QArray, with the same dtype and shape as
+  the unquantized array. In the forward pass, `_grad` does nothing and should
+  never be consumed. In the backward pass, `_grad` carries the gradient of
+  the whole QArray.
+
+  This approach works around a JAX restriction on gradient types: the
+  gradient of an int8[128,128] qvalue must be float0[128,128], while the
+  gradient of a float32[1,1] scale must be float32[1,1]. An alternative is
+  to define the QArray as a new Hijax type, which is more complex.
+  """
+
+  _grad: jax.Array = flax.struct.field(kw_only=True)
+
+
+def quantize_with_calibration(
+    array: jax.Array,
+    qtype: jax.typing.DTypeLike,
+    calibration: Mapping[str, jax.Array],
+    clip_gradient: bool = False,
+) -> QArrayWithGradient:
+  """Quantizes an array from its calibration, with backpropagation support.
+
+  Args:
+    array: The array to quantize.
+    qtype: The quantized type.
+    calibration: The calibration of the array.
+    clip_gradient: Whether to clip the straight-through estimator to the
+      calibration range, i.e., the gradient is 0 outside the calibration range.
+
+  Returns:
+    The quantized array with backpropagation support.
+  """
+  scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
+  res = qarray.quantize_with_scale_zero_point(array, qtype, scale, zero_point)
+  if clip_gradient:
+    array = qarray.clip_to_calibration(
+        array, calibration, qarray.get_tiled_axes(res)
+    )
+  # Do not allow gradients on the quantized array to flow back to the input.
+  res = jax.lax.stop_gradient(res)
+  return QArrayWithGradient(**dataclasses.asdict(res), _grad=array)
+
+
+@jax.custom_jvp
+def dequantize(array: QArrayWithGradient) -> jax.Array:
+  """Dequantizes an array."""
+  return qarray.dequantize(array)
+
+
+@dequantize.defjvp
+def _dequantize_jvp(primals, tangents):
+  return dequantize(*primals), tangents[0]._grad  # pylint: disable=protected-access
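Taken together, `quantize_with_calibration` and the custom-JVP `dequantize` implement a straight-through estimator: the forward value is the quantized round trip, while the gradient passes through `_grad` untouched (or clipped, when `clip_gradient=True`). A self-contained sketch mirroring the new unit test further below:

```python
import jax
import jax.numpy as jnp
from qwix._src.core import qarray
from qwix._src.core import qarray_qt

def fake_quant(x):
  how = qarray.HowToQuantize(qtype=jnp.int8)
  q = qarray_qt.quantize_with_calibration(
      x, how.qtype, qarray.calibrate(x, how)
  )
  return qarray_qt.dequantize(q)  # custom JVP reads the _grad attribute

x = jnp.linspace(-1.0, 1.0, 8)
# With clip_gradient=False the round-trip gradient is the identity (all ones).
print(jax.grad(lambda x: jnp.sum(fake_quant(x)))(x))
```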
diff --git a/qwix/_src/core/ragged_dot_qt.py b/qwix/_src/core/ragged_dot_qt.py
index 1521965..005bfe3 100644
--- a/qwix/_src/core/ragged_dot_qt.py
+++ b/qwix/_src/core/ragged_dot_qt.py
@@ -18,7 +18,9 @@
 import functools

 import jax
 from qwix._src import interception
+from qwix._src.core import numerics
 from qwix._src.core import qarray
+from qwix._src.core import qarray_qt
 from qwix._src.core import ragged_dot

@@ -34,11 +36,14 @@ class RaggedDotQtConfig:
   dlhs_grad_qtype: jax.typing.DTypeLike | None = None
   drhs_grad_qtype: jax.typing.DTypeLike | None = None
+
+  # Misc.
+  clip_gradients: bool = False


 @interception.disable_interceptions
 def ragged_dot_qt_fwd(
-    lhs: jax.Array,
-    rhs: jax.Array,
+    lhs: jax.Array | qarray_qt.QArrayWithGradient,
+    rhs: jax.Array | qarray_qt.QArrayWithGradient,
     group_sizes: jax.Array,
     config: RaggedDotQtConfig,
     precision: jax.lax.PrecisionLike = None,
@@ -46,20 +51,11 @@
     preferred_element_type: jax.typing.DTypeLike | None = None,
     group_offset: jax.Array | None = None,
 ):
   """Forward pass for ragged_dot_qt custom VJP."""
-  qlhs = qarray.quantize(
-      # lhs shape [M, K]: contracting axis=1, channelwise axis=0
-      lhs,
-      qarray.HowToQuantize(qtype=config.lhs_qtype, channelwise_axes=[0]),
-  )
-  qrhs = qarray.quantize(
-      # rhs shape [G, K, N]: contracting axis=1, channelwise axes=2
-      rhs,
-      qarray.HowToQuantize(qtype=config.rhs_qtype, channelwise_axes=[2]),
-  )
+  del config
   primal_out = ragged_dot.ragged_dot(
-      qlhs, qrhs, group_sizes, precision, preferred_element_type, group_offset
+      lhs, rhs, group_sizes, precision, preferred_element_type, group_offset
   )
-  return primal_out, (qlhs, qrhs, group_sizes)
+  return primal_out, (lhs, rhs, group_sizes)


 def ragged_dot_qt_bwd(
@@ -69,9 +65,17 @@
     precision: jax.lax.PrecisionLike,
     preferred_element_type: jax.typing.DTypeLike | None,
     group_offset: jax.Array | None,
     # Residuals from fwd pass
-    residuals: tuple[qarray.MaybeQArray, qarray.MaybeQArray, jax.Array],
+    residuals: tuple[
+        jax.Array | qarray_qt.QArrayWithGradient,
+        jax.Array | qarray_qt.QArrayWithGradient,
+        jax.Array,
+    ],
     g: jax.Array,
-) -> tuple[jax.Array, jax.Array, None]:
+) -> tuple[
+    jax.Array | qarray_qt.QArrayWithGradient,
+    jax.Array | qarray_qt.QArrayWithGradient,
+    None,
+]:
   """Backward pass for ragged_dot_qt custom VJP."""
   (lhs, rhs, group_sizes) = residuals
   # lhs [M, K], rhs [G, K, N], g [M, N]
@@ -99,6 +103,10 @@
       preferred_element_type=preferred_element_type,
       group_offset=group_offset,
   )
+  if isinstance(residuals[0], qarray_qt.QArrayWithGradient):
+    dlhs = dataclasses.replace(
+        residuals[0], qvalue=None, scale=None, zero_point=None, _grad=dlhs
+    )

   # drhs = ragged_dot_general(lhs, g)
   # [G, K, N] = [M, K] @ [M, N]
@@ -127,11 +135,39 @@
       preferred_element_type=preferred_element_type,
       group_offset=group_offset,
   )
+  if isinstance(residuals[1], qarray_qt.QArrayWithGradient):
+    drhs = dataclasses.replace(
+        residuals[1], qvalue=None, scale=None, zero_point=None, _grad=drhs
+    )

   return dlhs, drhs, None


 @functools.partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
+def ragged_dot_qt_fwd_bwd(
+    lhs: jax.Array | qarray_qt.QArrayWithGradient,
+    rhs: jax.Array | qarray_qt.QArrayWithGradient,
+    group_sizes: jax.Array,
+    config: RaggedDotQtConfig,
+    precision: jax.lax.PrecisionLike = None,
+    preferred_element_type: jax.typing.DTypeLike | None = None,
+    group_offset: jax.Array | None = None,
+) -> jax.Array:
+  """ragged_dot custom VJP."""
+  del config
+  return ragged_dot.ragged_dot(
+      lhs,
+      rhs,
+      group_sizes,
+      precision,
+      preferred_element_type,
+      group_offset,
+  )
+
+
+ragged_dot_qt_fwd_bwd.defvjp(ragged_dot_qt_fwd, ragged_dot_qt_bwd)
+
+
 def ragged_dot_qt(
     lhs: jax.Array,
     rhs: jax.Array,
@@ -142,7 +178,21 @@
     group_offset: jax.Array | None = None,
 ) -> jax.Array:
   """Quantized ragged_dot with backpropagation support."""
-  result, _ = ragged_dot_qt_fwd(
+  if config.lhs_qtype and numerics.should_quantize(lhs.dtype):
+    # lhs shape [M, K]: contracting axis=1, channelwise axis=0
+    lhs_how = qarray.HowToQuantize(qtype=config.lhs_qtype, channelwise_axes=[0])
+    calibration = qarray.calibrate(lhs, lhs_how)
+    lhs = qarray_qt.quantize_with_calibration(
+        lhs, lhs_how.qtype, calibration, clip_gradient=config.clip_gradients
+    )
+  if config.rhs_qtype and numerics.should_quantize(rhs.dtype):
+    # rhs shape [G, K, N]: contracting axis=1, channelwise axes=2
+    rhs_how = qarray.HowToQuantize(qtype=config.rhs_qtype, channelwise_axes=[2])
+    calibration = qarray.calibrate(rhs, rhs_how)
+    rhs = qarray_qt.quantize_with_calibration(
+        rhs, rhs_how.qtype, calibration, clip_gradient=config.clip_gradients
+    )
+  return ragged_dot_qt_fwd_bwd(
     lhs,
     rhs,
     group_sizes,
@@ -151,7 +201,3 @@
     preferred_element_type,
     group_offset,
   )
-  return result
-
-
-ragged_dot_qt.defvjp(ragged_dot_qt_fwd, ragged_dot_qt_bwd)
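For `ragged_dot_qt` the channelwise axes are fixed by the operand layouts, [M, K] for lhs and [G, K, N] for rhs, as the comments above note. An illustrative sketch (shapes and group sizes are made up, and it assumes the qtype fields default to None when unset):

```python
import jax.numpy as jnp
from qwix._src.core import ragged_dot_qt

config = ragged_dot_qt.RaggedDotQtConfig(
    lhs_qtype=jnp.int8, rhs_qtype=jnp.int8
)
lhs = jnp.ones((6, 4))              # [M, K] tokens
rhs = jnp.ones((3, 4, 5))           # [G, K, N] per-group weights
group_sizes = jnp.array([2, 1, 3])  # rows of lhs routed to each group
out = ragged_dot_qt.ragged_dot_qt(lhs, rhs, group_sizes, config)  # [M, N]
```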
diff --git a/qwix/_src/providers/qt.py b/qwix/_src/providers/qt.py
index 1f0d0cd..6032045 100644
--- a/qwix/_src/providers/qt.py
+++ b/qwix/_src/providers/qt.py
@@ -15,7 +15,7 @@
 import dataclasses
 import functools
-from typing import Any, Callable, Mapping, Sequence
+from typing import Callable, Sequence

 import jax
 from jax import numpy as jnp

@@ -34,7 +34,7 @@ class QtRule(qconfig.QuantizationRule):

   # In backward pass, quantize the gradients to the given type. This doesn't
   # affect the residuals as the residuals will reuse the quantization in the
-  # forward pass, unless bwd_use_original_residuals is set.
+  # forward pass.
   bwd_qtype: jax.typing.DTypeLike | None = None

   # In backward pass, calibrate the gradients using the given method.
@@ -48,11 +48,6 @@
   # If True, disable channelwise axes for both forward and backward passes.
   disable_channelwise_axes: bool = False

-  # If True, use the original values instead of the quantized values as the
-  # residuals for backward pass. Enabling this prevents using low-precision
-  # matmuls during bwd pass and has a negative impact on performance.
-  bwd_use_original_residuals: bool = False
-
   # Use stochastic rounding for the gradients. (Only 'uniform' is supported.)
   bwd_stochastic_rounding: str | None = None

@@ -60,10 +55,12 @@
   # noise for the 0th dimension and broadcast it over remaining dimensions.
   channelwise_noise_axes: Sequence[int] = (0,)

-  # Override any fields in DotGeneralQtConfig or ConvGeneralQtConfig. This is
-  # highly experimental and subjects to changes with no backward compatibility
-  # guarantees.
-  additional_qt_config: Mapping[str, Any] | None = None
+  # Whether to clip gradients: if True, values outside the calibration range
+  # receive zero gradients. This is unnecessary when the calibration method is
+  # "absmax" or "minmax", since those produce no out-of-calibration values by
+  # definition. Enabling it may improve quality at the cost of additional
+  # computation.
+  clip_gradients: bool = False


 class QtProvider(qconfig.QuantizationProvider):
@@ -293,7 +290,7 @@
       drhs_grad_calibration_method=rule.bwd_calibration_method,
       # misc.
       disable_channelwise_axes=rule.disable_channelwise_axes,
-      bwd_use_original_residuals=rule.bwd_use_original_residuals,
+      clip_gradients=rule.clip_gradients,
   )


 def _create_dot_general_qt_config(
@@ -386,19 +383,17 @@
       dlhs_grad_qtype=rule.bwd_qtype,
       dlhs_grad_calibration_method=rule.bwd_calibration_method,
       dlhs_tile_size=dlhs_tile_size,
+      dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn,
       # drhs configs.
       drhs_grad_qtype=rule.bwd_qtype,
       drhs_tile_size=drhs_tile_size,
       drhs_grad_calibration_method=rule.bwd_calibration_method,
+      drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn,
       # misc.
       disable_channelwise_axes=rule.disable_channelwise_axes,
-      bwd_use_original_residuals=rule.bwd_use_original_residuals,
-      dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn,
-      drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn,
+      clip_gradients=rule.clip_gradients,
   )
-  if rule.additional_qt_config:
-    qt_config = dataclasses.replace(qt_config, **rule.additional_qt_config)
   return qt_config


 def _create_ragged_dot_qt_config(
@@ -415,4 +410,6 @@
       # bwd configs.
       dlhs_grad_qtype=rule.bwd_qtype,
       drhs_grad_qtype=rule.bwd_qtype,
+      # misc.
+      clip_gradients=rule.clip_gradients,
   )
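At the rule level, `clip_gradients` now fans out to all three op configs via the `_create_*_qt_config` helpers. A hedged configuration sketch (field values are illustrative; `module_path` comes from the base QuantizationRule, as in the test change below):

```python
import jax.numpy as jnp
from qwix._src.providers import qt

rule = qt.QtRule(
    module_path='.*',     # apply to every module
    weight_qtype=jnp.int8,
    act_qtype=jnp.int8,
    bwd_qtype=jnp.int8,   # also quantize incoming gradients
    # Zero out gradients for values outside the calibration range. This is a
    # no-op under "absmax"/"minmax" calibration, which never produce
    # out-of-calibration values.
    clip_gradients=True,
)
provider = qt.QtProvider([rule])
```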
diff --git a/tests/core/qarray_qt_test.py b/tests/core/qarray_qt_test.py
new file mode 100644
index 0000000..8797937
--- /dev/null
+++ b/tests/core/qarray_qt_test.py
@@ -0,0 +1,40 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+from jax import numpy as jnp
+from qwix._src.core import qarray
+from qwix._src.core import qarray_qt
+
+
+class QArrayQtTest(parameterized.TestCase):
+
+  def test_qarray_with_gradient(self):
+    x = jnp.ones((3, 3), jnp.float32)
+
+    def fake_quant_sum(x):
+      how = qarray.HowToQuantize(qtype=jnp.int8)
+      x = qarray_qt.quantize_with_calibration(
+          x, how.qtype, qarray.calibrate(x, how)
+      )
+      x = qarray_qt.dequantize(x)
+      return jnp.sum(x)
+
+    self.assertTrue((jax.grad(fake_quant_sum)(x) == x).all())
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/tests/providers/qt_test.py b/tests/providers/qt_test.py
index 94adca3..df3ced3 100644
--- a/tests/providers/qt_test.py
+++ b/tests/providers/qt_test.py
@@ -65,19 +65,21 @@ def loss_fn(params):
     self.assertEqual(quant_stats["dot_general0_lhs"]["count"], 1)

   def test_srq_jit_grad_nnx(self):
-    """Test SRQ on NNX module."""
-    linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0), param_dtype=jnp.bfloat16)
-    qt_provider = qt.QtProvider([
-        qconfig.QuantizationRule(
-            module_path=".*",
-            weight_qtype=jnp.int8,
-            act_qtype=jnp.int8,
-            act_static_scale=True,
-        ),
-    ])
+    """Test creating and training an SRQ NNX model inside jit."""
+
+    def create_srq_nnx_model(model_input):
+      linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0), param_dtype=jnp.bfloat16)
+      qt_provider = qt.QtProvider([
+          qconfig.QuantizationRule(
+              weight_qtype=jnp.int8,
+              act_qtype=jnp.int8,
+              act_static_scale=True,
+          ),
+      ])
+      return qwix_model.quantize_model(linear, qt_provider, model_input)

     model_input = jnp.ones((10, 12), dtype=jnp.float32)
-    qt_linear = qwix_model.quantize_model(linear, qt_provider, model_input)
+    qt_linear = nnx.jit(create_srq_nnx_model)(model_input)
     quant_stats = nnx.variables(qt_linear, flax_util.QuantStat)
     # quant_stats should be initialized but empty.