Skip to content

Commit 4ff3eed

Browse files
[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (NVIDIA#2348)
* Add test to check jaxpr that amax is reused for nvfp4 recipe Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Move test to test_helper.py and rename file Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b14a3b6 commit 4ff3eed

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

tests/jax/test_custom_call_compute.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
from transformer_engine.jax.activation import activation
4646
from transformer_engine.jax.dense import dense, grouped_dense
4747
from transformer_engine.jax.layernorm_dense import layernorm_dense
48-
from transformer_engine.common import recipe
4948

5049
GEMM_CASES = [
5150
(256, 256, 512),
Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
from flax import linen as nn
1313

14-
from utils import assert_allclose
14+
from utils import assert_allclose, pytest_parametrize_wrapper
1515
from transformer_engine.common.recipe import (
1616
DelayedScaling,
1717
MXFP8BlockScaling,
@@ -22,6 +22,7 @@
2222
from transformer_engine.jax import autocast
2323
from transformer_engine.jax.quantize import (
2424
get_quantize_config,
25+
get_supported_quantization_recipes,
2526
is_scaling_mode_supported,
2627
ScalingMode,
2728
update_collections,
@@ -32,11 +33,15 @@
3233
from transformer_engine.jax.quantize.helper import _format2dtypes
3334
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
3435
from transformer_engine.jax.flax.module import TransformerEngineBase
36+
from transformer_engine.jax import flax as te_flax
37+
import transformer_engine.jax as te
3538

3639
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
3740
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
3841
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
3942

43+
SUPPORTED_RECIPES = get_supported_quantization_recipes()
44+
4045

4146
def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
4247
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
@@ -253,3 +258,63 @@ def test_autocast_nvfp4_block_scaling(self):
253258
self._compare_nvfp4_scaling_quantizers(bs)
254259

255260
self._check_default_state()
261+
262+
263+
class TestJaxprAndHlo:
264+
"""Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
265+
266+
@pytest_parametrize_wrapper(
267+
"quantization_recipe",
268+
[
269+
quantization_recipe
270+
for quantization_recipe in SUPPORTED_RECIPES
271+
if isinstance(quantization_recipe, NVFP4BlockScaling)
272+
],
273+
)
274+
def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
275+
"""Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
276+
277+
with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
278+
model = te_flax.LayerNormMLP(
279+
layernorm_type="rmsnorm",
280+
return_layernorm_output=False,
281+
intermediate_dropout_rate=0.0,
282+
dtype=jnp.bfloat16,
283+
)
284+
285+
var_collect = model.init(
286+
jax.random.PRNGKey(0),
287+
jnp.ones((128, 128), dtype=jnp.bfloat16),
288+
)
289+
290+
def loss_fn(x, rngs):
291+
return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])
292+
293+
x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
294+
rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
295+
jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
296+
297+
rht_amax_eqns = [
298+
eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
299+
]
300+
301+
assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
302+
303+
def assert_param(index, tensor_name, expected_value: bool):
304+
if expected_value:
305+
assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
306+
f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
307+
" reuse of amax as this tensor does not have a previous operation to fuse"
308+
" with"
309+
)
310+
else:
311+
assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
312+
f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
313+
" reuse of amax"
314+
)
315+
316+
assert_param(0, "fwd ln+q", False)
317+
assert_param(1, "fwd act+q", False)
318+
# No previous op before incoming dgrad in the backward so amax is not reused
319+
assert_param(2, "bwd dgrad", True)
320+
assert_param(3, "bwd dact+q", False)

0 commit comments

Comments
 (0)