Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions qwix/_src/interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ def wrapper(*args, **kwargs):
):
interceptor[name + ".__code__"] = interceptor.pop(name)

# Disable JIT if we are intercepting the primitive bind. Primitives are
# completely hidden in C++ when JIT is enabled.
if (
"jax.core.Primitive.bind" in interceptor
or "jax.extend.core.Primitive.bind" in interceptor
):
need_to_disable_jit = True

# Check if JIT is already disabled.
if jax.config.jax_disable_jit:
need_to_disable_jit = False
Expand Down
7 changes: 3 additions & 4 deletions qwix/_src/providers/odml.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,10 @@ def get_intercept_map(self):
intercept_map['jax.lax.dot_general'] = functools.partial(
self._flatten_dot_general,
_dot_general=intercept_map['jax.lax.dot_general'],
_reshape=intercept_map['jax.lax.reshape'],
)
return intercept_map

def _flatten_dot_general(self, *args, _dot_general, _reshape, **kwargs):
def _flatten_dot_general(self, *args, _dot_general, **kwargs):
"""Flatten N-D weights to 2-D to support channelwise quantization."""
# This special handling is needed because tflite doesn't support multiple
# quantization_dimensions.
Expand All @@ -296,9 +295,9 @@ def _flatten_dot_general(self, *args, _dot_general, _reshape, **kwargs):
):
args = list(args)
dout = args[1].shape[1:]
args[1] = _reshape(args[1], (args[1].shape[0], np.prod(dout)))
args[1] = jax.lax.reshape(args[1], (args[1].shape[0], np.prod(dout)))
out = _dot_general(*args, **kwargs)
return _reshape(out, out.shape[:-1] + dout)
return jax.lax.reshape(out, out.shape[:-1] + dout)
return _dot_general(*args, **kwargs)

def _fake_quant(
Expand Down
112 changes: 63 additions & 49 deletions qwix/_src/providers/odml_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,36 +66,28 @@ def get_all_ops():
# l2_norm_full_name = l2_norm.__module__ + '.' + l2_norm.__name__
# provider._ops[l2_norm_full_name] = provider._ops['jax.numpy.tanh']

partial = functools.partial
quantize = lambda *a, **k: functools.partial(QuantizedOp, input_idx=a, **k)

return {
# go/keep-sorted start
'flax.linen.BatchNorm.__call__': BatchNorm,
'flax.linen.Dropout.__call__': partial(TransparentOp, input_idx=[1]),
# flax.linen.Dropout.__call__ is handled by PrimitiveBind.
'flax.linen.GroupNorm.__call__': quantize(1, op_name='norm_op'),
'flax.linen.LayerNorm.__call__': quantize(1, op_name='norm_op'),
'flax.linen.avg_pool': OnlyInputOp,
'flax.linen.max_pool': OnlyInputOp,
'flax.nnx.BatchNorm.__call__': BatchNorm,
'jax._src.numpy.indexing.rewriting_take': Take, # a.__getitem__
'jax.custom_jvp.__call__': CustomJvpCall, # handles relu and relu6.
'jax.extend.core.Primitive.bind': PrimitiveBind,
'jax.image.resize': OnlyInputOp,
'jax.lax.broadcast_in_dim': TransparentOp,
'jax.lax.conv_general_dilated': DotEinsumConv,
'jax.lax.dot_general': DotEinsumConv,
'jax.lax.reshape': TransparentOp,
'jax.lax.split': Split,
'jax.lax.squeeze': TransparentOp,
'jax.lax.stop_gradient': TransparentOp,
'jax.lax.transpose': TransparentOp,
'jax.lax.with_sharding_constraint': TransparentOp,
'jax.nn.gelu': quantize(0),
'jax.nn.leaky_relu': quantize(0),
'jax.nn.silu': Silu,
'jax.nn.softmax': Softmax,
'jax.numpy.array': TransparentOp,
'jax.numpy.astype': TransparentOp,
'jax.numpy.clip': OnlyOutputOp,
'jax.numpy.concatenate': Concatenate,
'jax.numpy.cos': NoQuantOp,
Expand All @@ -107,7 +99,6 @@ def get_all_ops():
'jax.numpy.pad': OnlyOutputOp,
'jax.numpy.repeat': quantize(0), # not fully supported by the converter.
'jax.numpy.sin': NoQuantOp,
'jax.numpy.squeeze': TransparentOp,
'jax.numpy.sum': quantize(0),
'jax.numpy.take': Take,
'jax.numpy.tanh': Tanh,
Expand Down Expand Up @@ -160,6 +151,67 @@ def get_all_ops():
FakeQuantFn = Callable[[jax.Array, qarray.HowToQuantize, str | None], jax.Array]


class PrimitiveBind:
"""Intercepts jax.core.Primitive.bind to propagate aux_data."""

def __init__(self, **kwargs):
# This op is initialized with the same arguments as QuantizedOp, but
# we don't need them.
pass

def __call__(self, primitive, *args, **kwargs):
"""Intercepts jax.core.Primitive.bind to propagate aux_data."""
# This interceptor is called for EVERY primitive bind.
# It checks if any input has relevant aux_data and propagates it to the
# output.
# This is needed because some ops like reshape/transpose/broadcast are
# implemented as primitives and they return new arrays, losing aux_data.

# 1. Collect aux_data from inputs.
# We only care about specific keys that need to be propagated.
forwarded_keys = (
_IS_ACTIVATION,
_WEIGHT_NAME,
_FQ_RULE,
_FIXED_RANGE,
_ALLOW_FUSION,
_FQ_ARRAY,
)
aux_data_to_forward = {}
for arg in jax.tree.leaves((args, kwargs)):
if isinstance(arg, jax.Array):
for key in forwarded_keys:
if key not in aux_data_to_forward:
value = aux_data.get(arg, key, None)
if value is not None:
aux_data_to_forward[key] = value

# special handling for fq_array: if it is 'self', we should not forward it
# because the new array is not the same as the old one.
if aux_data_to_forward.get(_FQ_ARRAY) == 'self':
del aux_data_to_forward[_FQ_ARRAY]

# 2. Call the original bind.
out = primitive.bind(*args, **kwargs)

# 3. Propagate aux_data to output.
if aux_data_to_forward:

def forward(x):
if isinstance(x, jax.Array):
for key, value in aux_data_to_forward.items():
# Don't overwrite existing aux_data if the op itself set it.
# (Though primitive.bind usually doesn't set aux_data unless
# intercepted by higher level logic, but here we are at the
# lowest level).
if aux_data.get(x, key, None) is None:
aux_data.set(x, key, value)

jax.tree.map(forward, out)

return out


class QuantizedOp:
"""A generic quantized op that allows different scales for inputs and output.

Expand Down Expand Up @@ -381,44 +433,6 @@ def __call__(self, *args, **kwargs):
return self._fake_quant_output(out, rule)


class TransparentOp(QuantizedOp):
"""A transparent op doesn't quantize anything. It only forwards aux data.

This is mainly used for ops that doesn't change the actual value of the
inputs, e.g. reshape, transpose, etc. This is similar to OnlyOutputOp,
but it reuses the previous rule and forwards more aux data.
"""

input_idx = [0] # default to a unary op.
forwarded_aux_data = (
_IS_ACTIVATION,
_WEIGHT_NAME,
_FQ_RULE,
_FIXED_RANGE,
_ALLOW_FUSION,
)

def __call__(self, *args, **kwargs):
if len(self.input_idx) > 1:
raise ValueError(
f'Unsupported num of inputs {self.input_idx} for op {self._op_name}.'
)
out = self._call_original_op(*args, **kwargs)

def forward(out, arg):
for key in self.forwarded_aux_data:
value = aux_data.get(arg, key, None)
if value is not None:
aux_data.set(out, key, value)
# Also forward the FQ_ARRAY if it's used to skip the quantization.
fq_array = aux_data.get(arg, _FQ_ARRAY, None)
if fq_array == 'self':
aux_data.set(out, _FQ_ARRAY, fq_array)

jax.tree.map(forward, out, args[self.input_idx[0]])
return out


class NoQuantOp(QuantizedOp):
"""An fp op doesn't have a corresponding quantized op."""

Expand Down
115 changes: 115 additions & 0 deletions tests/_src/providers/odml_bind_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the ODML provider's metadata propagation."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from qwix._src import aux_data
from qwix._src import interception
from qwix._src.providers import odml
from qwix._src.providers import odml_ops


wrap_func_intercepted = interception.wrap_func_intercepted


class OdmlBindTest(parameterized.TestCase):

def test_metadata_propagation_reshape(self):
"""Test that metadata propagates through reshape."""
provider = odml.OdmlQatProvider([])

def run():
x = jnp.zeros((2, 3))
aux_data.set(x, odml_ops._IS_ACTIVATION, True)
aux_data.set(x, odml_ops._FQ_RULE, 'some_rule')

# Reshape creates a new array via primitive bind
y = jnp.reshape(x, (6,))

self.assertTrue(aux_data.get(y, odml_ops._IS_ACTIVATION, False))
self.assertEqual(aux_data.get(y, odml_ops._FQ_RULE, None), 'some_rule')
return y

wrap_func_intercepted(run, provider.get_intercept_map)()

def test_metadata_propagation_transpose(self):
"""Test that metadata propagates through transpose."""
provider = odml.OdmlQatProvider([])

def run():
x = jnp.zeros((2, 3))
aux_data.set(x, 'weight_name', 'w')

y = jnp.transpose(x)

self.assertEqual(aux_data.get(y, 'weight_name', None), 'w')
return y

wrap_func_intercepted(run, provider.get_intercept_map)()

def test_metadata_propagation_chain(self):
"""Test chain of operations."""
provider = odml.OdmlQatProvider([])

def run():
x = jnp.zeros((2, 3, 4))
aux_data.set(x, odml_ops._IS_ACTIVATION, True)

y = jnp.transpose(x, (0, 2, 1)) # (2, 4, 3)
z = jnp.reshape(y, (8, 3))
w = jnp.expand_dims(
z, axis=0
) # (1, 8, 3) --> broadcast_in_dim often used

self.assertTrue(aux_data.get(w, odml_ops._IS_ACTIVATION, False))
return w

wrap_func_intercepted(run, provider.get_intercept_map)()

def test_jit_interception(self):
"""Test that interception works even with JIT (by disabling it)."""
provider = odml.OdmlQatProvider([])

def run():
@jax.jit
def f(x):
# Inside JIT, primitives are usually hidden, but our interceptor
# disables JIT
y = jnp.reshape(x, (6,))
return y

x = jnp.zeros((2, 3))
aux_data.set(x, 'is_activation', True)

y = f(x)
# If JIT was successfully disabled (or intercepted), we might see the tag?
# Actually, if JIT is disabled, f runs in python, so interception works.
# If JIT is ENABLED, f returns a tracer or compiled code, and inside it
# primitive binds happen during tracing.
# Our interceptor disables JIT, so it should run eagerly or at least
# python-level.

self.assertTrue(aux_data.get(y, odml_ops._IS_ACTIVATION, False))
return y

# Note: This test might fail if `jax_disable_jit` logic isn't working as
# expected or if `interception.py` doesn't pick it up.
wrap_func_intercepted(run, provider.get_intercept_map)()


if __name__ == '__main__':
absltest.main()