diff --git a/qwix/_src/interception.py b/qwix/_src/interception.py index 2572658..e3b1a01 100644 --- a/qwix/_src/interception.py +++ b/qwix/_src/interception.py @@ -125,6 +125,14 @@ def wrapper(*args, **kwargs): ): interceptor[name + ".__code__"] = interceptor.pop(name) + # Disable JIT if we are intercepting the primitive bind. Primitives are + # completely hidden in C++ when JIT is enabled. + if ( + "jax.core.Primitive.bind" in interceptor + or "jax.extend.core.Primitive.bind" in interceptor + ): + need_to_disable_jit = True + # Check if JIT is already disabled. if jax.config.jax_disable_jit: need_to_disable_jit = False diff --git a/qwix/_src/providers/odml.py b/qwix/_src/providers/odml.py index a675f60..6fa414d 100644 --- a/qwix/_src/providers/odml.py +++ b/qwix/_src/providers/odml.py @@ -281,11 +281,10 @@ def get_intercept_map(self): intercept_map['jax.lax.dot_general'] = functools.partial( self._flatten_dot_general, _dot_general=intercept_map['jax.lax.dot_general'], - _reshape=intercept_map['jax.lax.reshape'], ) return intercept_map - def _flatten_dot_general(self, *args, _dot_general, _reshape, **kwargs): + def _flatten_dot_general(self, *args, _dot_general, **kwargs): """Flatten N-D weights to 2-D to support channelwise quantization.""" # This special handling is needed because tflite doesn't support multiple # quantization_dimensions. @@ -296,9 +295,9 @@ def _flatten_dot_general(self, *args, _dot_general, _reshape, **kwargs): ): args = list(args) dout = args[1].shape[1:] - args[1] = _reshape(args[1], (args[1].shape[0], np.prod(dout))) + args[1] = jax.lax.reshape(args[1], (args[1].shape[0], np.prod(dout))) out = _dot_general(*args, **kwargs) - return _reshape(out, out.shape[:-1] + dout) + return jax.lax.reshape(out, out.shape[:-1] + dout) return _dot_general(*args, **kwargs) def _fake_quant( diff --git a/qwix/_src/providers/odml_ops.py b/qwix/_src/providers/odml_ops.py index c2e423c..7c49a1a 100644 --- a/qwix/_src/providers/odml_ops.py +++ b/qwix/_src/providers/odml_ops.py @@ -66,13 +66,12 @@ def get_all_ops(): # l2_norm_full_name = l2_norm.__module__ + '.' + l2_norm.__name__ # provider._ops[l2_norm_full_name] = provider._ops['jax.numpy.tanh'] - partial = functools.partial quantize = lambda *a, **k: functools.partial(QuantizedOp, input_idx=a, **k) return { # go/keep-sorted start 'flax.linen.BatchNorm.__call__': BatchNorm, - 'flax.linen.Dropout.__call__': partial(TransparentOp, input_idx=[1]), + # flax.linen.Dropout.__call__ is handled by PrimitiveBind. 'flax.linen.GroupNorm.__call__': quantize(1, op_name='norm_op'), 'flax.linen.LayerNorm.__call__': quantize(1, op_name='norm_op'), 'flax.linen.avg_pool': OnlyInputOp, @@ -80,22 +79,15 @@ def get_all_ops(): 'flax.nnx.BatchNorm.__call__': BatchNorm, 'jax._src.numpy.indexing.rewriting_take': Take, # a.__getitem__ 'jax.custom_jvp.__call__': CustomJvpCall, # handles relu and relu6. + 'jax.extend.core.Primitive.bind': PrimitiveBind, 'jax.image.resize': OnlyInputOp, - 'jax.lax.broadcast_in_dim': TransparentOp, 'jax.lax.conv_general_dilated': DotEinsumConv, 'jax.lax.dot_general': DotEinsumConv, - 'jax.lax.reshape': TransparentOp, 'jax.lax.split': Split, - 'jax.lax.squeeze': TransparentOp, - 'jax.lax.stop_gradient': TransparentOp, - 'jax.lax.transpose': TransparentOp, - 'jax.lax.with_sharding_constraint': TransparentOp, 'jax.nn.gelu': quantize(0), 'jax.nn.leaky_relu': quantize(0), 'jax.nn.silu': Silu, 'jax.nn.softmax': Softmax, - 'jax.numpy.array': TransparentOp, - 'jax.numpy.astype': TransparentOp, 'jax.numpy.clip': OnlyOutputOp, 'jax.numpy.concatenate': Concatenate, 'jax.numpy.cos': NoQuantOp, @@ -107,7 +99,6 @@ def get_all_ops(): 'jax.numpy.pad': OnlyOutputOp, 'jax.numpy.repeat': quantize(0), # not fully supported by the converter. 'jax.numpy.sin': NoQuantOp, - 'jax.numpy.squeeze': TransparentOp, 'jax.numpy.sum': quantize(0), 'jax.numpy.take': Take, 'jax.numpy.tanh': Tanh, @@ -160,6 +151,67 @@ def get_all_ops(): FakeQuantFn = Callable[[jax.Array, qarray.HowToQuantize, str | None], jax.Array] +class PrimitiveBind: + """Intercepts jax.core.Primitive.bind to propagate aux_data.""" + + def __init__(self, **kwargs): + # This op is initialized with the same arguments as QuantizedOp, but + # we don't need them. + pass + + def __call__(self, primitive, *args, **kwargs): + """Intercepts jax.core.Primitive.bind to propagate aux_data.""" + # This interceptor is called for EVERY primitive bind. + # It checks if any input has relevant aux_data and propagates it to the + # output. + # This is needed because some ops like reshape/transpose/broadcast are + # implemented as primitives and they return new arrays, losing aux_data. + + # 1. Collect aux_data from inputs. + # We only care about specific keys that need to be propagated. + forwarded_keys = ( + _IS_ACTIVATION, + _WEIGHT_NAME, + _FQ_RULE, + _FIXED_RANGE, + _ALLOW_FUSION, + _FQ_ARRAY, + ) + aux_data_to_forward = {} + for arg in jax.tree.leaves((args, kwargs)): + if isinstance(arg, jax.Array): + for key in forwarded_keys: + if key not in aux_data_to_forward: + value = aux_data.get(arg, key, None) + if value is not None: + aux_data_to_forward[key] = value + + # special handling for fq_array: if it is 'self', we should not forward it + # because the new array is not the same as the old one. + if aux_data_to_forward.get(_FQ_ARRAY) == 'self': + del aux_data_to_forward[_FQ_ARRAY] + + # 2. Call the original bind. + out = primitive.bind(*args, **kwargs) + + # 3. Propagate aux_data to output. + if aux_data_to_forward: + + def forward(x): + if isinstance(x, jax.Array): + for key, value in aux_data_to_forward.items(): + # Don't overwrite existing aux_data if the op itself set it. + # (Though primitive.bind usually doesn't set aux_data unless + # intercepted by higher level logic, but here we are at the + # lowest level). + if aux_data.get(x, key, None) is None: + aux_data.set(x, key, value) + + jax.tree.map(forward, out) + + return out + + class QuantizedOp: """A generic quantized op that allows different scales for inputs and output. @@ -381,44 +433,6 @@ def __call__(self, *args, **kwargs): return self._fake_quant_output(out, rule) -class TransparentOp(QuantizedOp): - """A transparent op doesn't quantize anything. It only forwards aux data. - - This is mainly used for ops that doesn't change the actual value of the - inputs, e.g. reshape, transpose, etc. This is similar to OnlyOutputOp, - but it reuses the previous rule and forwards more aux data. - """ - - input_idx = [0] # default to a unary op. - forwarded_aux_data = ( - _IS_ACTIVATION, - _WEIGHT_NAME, - _FQ_RULE, - _FIXED_RANGE, - _ALLOW_FUSION, - ) - - def __call__(self, *args, **kwargs): - if len(self.input_idx) > 1: - raise ValueError( - f'Unsupported num of inputs {self.input_idx} for op {self._op_name}.' - ) - out = self._call_original_op(*args, **kwargs) - - def forward(out, arg): - for key in self.forwarded_aux_data: - value = aux_data.get(arg, key, None) - if value is not None: - aux_data.set(out, key, value) - # Also forward the FQ_ARRAY if it's used to skip the quantization. - fq_array = aux_data.get(arg, _FQ_ARRAY, None) - if fq_array == 'self': - aux_data.set(out, _FQ_ARRAY, fq_array) - - jax.tree.map(forward, out, args[self.input_idx[0]]) - return out - - class NoQuantOp(QuantizedOp): """An fp op doesn't have a corresponding quantized op.""" diff --git a/tests/_src/providers/odml_bind_test.py b/tests/_src/providers/odml_bind_test.py new file mode 100644 index 0000000..e553d01 --- /dev/null +++ b/tests/_src/providers/odml_bind_test.py @@ -0,0 +1,115 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test the ODML provider's metadata propagation.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import numpy as jnp +from qwix._src import aux_data +from qwix._src import interception +from qwix._src.providers import odml +from qwix._src.providers import odml_ops + + +wrap_func_intercepted = interception.wrap_func_intercepted + + +class OdmlBindTest(parameterized.TestCase): + + def test_metadata_propagation_reshape(self): + """Test that metadata propagates through reshape.""" + provider = odml.OdmlQatProvider([]) + + def run(): + x = jnp.zeros((2, 3)) + aux_data.set(x, odml_ops._IS_ACTIVATION, True) + aux_data.set(x, odml_ops._FQ_RULE, 'some_rule') + + # Reshape creates a new array via primitive bind + y = jnp.reshape(x, (6,)) + + self.assertTrue(aux_data.get(y, odml_ops._IS_ACTIVATION, False)) + self.assertEqual(aux_data.get(y, odml_ops._FQ_RULE, None), 'some_rule') + return y + + wrap_func_intercepted(run, provider.get_intercept_map)() + + def test_metadata_propagation_transpose(self): + """Test that metadata propagates through transpose.""" + provider = odml.OdmlQatProvider([]) + + def run(): + x = jnp.zeros((2, 3)) + aux_data.set(x, 'weight_name', 'w') + + y = jnp.transpose(x) + + self.assertEqual(aux_data.get(y, 'weight_name', None), 'w') + return y + + wrap_func_intercepted(run, provider.get_intercept_map)() + + def test_metadata_propagation_chain(self): + """Test chain of operations.""" + provider = odml.OdmlQatProvider([]) + + def run(): + x = jnp.zeros((2, 3, 4)) + aux_data.set(x, odml_ops._IS_ACTIVATION, True) + + y = jnp.transpose(x, (0, 2, 1)) # (2, 4, 3) + z = jnp.reshape(y, (8, 3)) + w = jnp.expand_dims( + z, axis=0 + ) # (1, 8, 3) --> broadcast_in_dim often used + + self.assertTrue(aux_data.get(w, odml_ops._IS_ACTIVATION, False)) + return w + + wrap_func_intercepted(run, provider.get_intercept_map)() + + def test_jit_interception(self): + """Test that interception works even with JIT (by disabling it).""" + provider = odml.OdmlQatProvider([]) + + def run(): + @jax.jit + def f(x): + # Inside JIT, primitives are usually hidden, but our interceptor + # disables JIT + y = jnp.reshape(x, (6,)) + return y + + x = jnp.zeros((2, 3)) + aux_data.set(x, 'is_activation', True) + + y = f(x) + # If JIT was successfully disabled (or intercepted), we might see the tag? + # Actually, if JIT is disabled, f runs in python, so interception works. + # If JIT is ENABLED, f returns a tracer or compiled code, and inside it + # primitive binds happen during tracing. + # Our interceptor disables JIT, so it should run eagerly or at least + # python-level. + + self.assertTrue(aux_data.get(y, odml_ops._IS_ACTIVATION, False)) + return y + + # Note: This test might fail if `jax_disable_jit` logic isn't working as + # expected or if `interception.py` doesn't pick it up. + wrap_func_intercepted(run, provider.get_intercept_map)() + + +if __name__ == '__main__': + absltest.main()