|
11 | 11 | import numpy as np |
12 | 12 | from flax import linen as nn |
13 | 13 |
|
14 | | -from utils import assert_allclose |
| 14 | +from utils import assert_allclose, pytest_parametrize_wrapper |
15 | 15 | from transformer_engine.common.recipe import ( |
16 | 16 | DelayedScaling, |
17 | 17 | MXFP8BlockScaling, |
|
22 | 22 | from transformer_engine.jax import autocast |
23 | 23 | from transformer_engine.jax.quantize import ( |
24 | 24 | get_quantize_config, |
| 25 | + get_supported_quantization_recipes, |
25 | 26 | is_scaling_mode_supported, |
26 | 27 | ScalingMode, |
27 | 28 | update_collections, |
|
32 | 33 | from transformer_engine.jax.quantize.helper import _format2dtypes |
33 | 34 | from transformer_engine.jax.sharding import MeshResource, global_mesh_resource |
34 | 35 | from transformer_engine.jax.flax.module import TransformerEngineBase |
| 36 | +from transformer_engine.jax import flax as te_flax |
| 37 | +import transformer_engine.jax as te |
35 | 38 |
|
36 | 39 | is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) |
37 | 40 | is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) |
38 | 41 | is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) |
39 | 42 |
|
| 43 | +SUPPORTED_RECIPES = get_supported_quantization_recipes() |
| 44 | + |
40 | 45 |
|
41 | 46 | def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): |
42 | 47 | """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries.""" |
@@ -253,3 +258,63 @@ def test_autocast_nvfp4_block_scaling(self): |
253 | 258 | self._compare_nvfp4_scaling_quantizers(bs) |
254 | 259 |
|
255 | 260 | self._check_default_state() |
| 261 | + |
| 262 | + |
| 263 | +class TestJaxprAndHlo: |
| 264 | + """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations.""" |
| 265 | + |
| 266 | + @pytest_parametrize_wrapper( |
| 267 | + "quantization_recipe", |
| 268 | + [ |
| 269 | + quantization_recipe |
| 270 | + for quantization_recipe in SUPPORTED_RECIPES |
| 271 | + if isinstance(quantization_recipe, NVFP4BlockScaling) |
| 272 | + ], |
| 273 | + ) |
| 274 | + def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe): |
| 275 | +        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantization.""" |
| 276 | + |
| 277 | + with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()): |
| 278 | + model = te_flax.LayerNormMLP( |
| 279 | + layernorm_type="rmsnorm", |
| 280 | + return_layernorm_output=False, |
| 281 | + intermediate_dropout_rate=0.0, |
| 282 | + dtype=jnp.bfloat16, |
| 283 | + ) |
| 284 | + |
| 285 | + var_collect = model.init( |
| 286 | + jax.random.PRNGKey(0), |
| 287 | + jnp.ones((128, 128), dtype=jnp.bfloat16), |
| 288 | + ) |
| 289 | + |
| 290 | + def loss_fn(x, rngs): |
| 291 | + return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0]) |
| 292 | + |
| 293 | + x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16) |
| 294 | + rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)} |
| 295 | + jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs) |
| 296 | + |
| 297 | + rht_amax_eqns = [ |
| 298 | + eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper" |
| 299 | + ] |
| 300 | + |
| 301 | + assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}" |
| 302 | + |
| 303 | + def assert_param(index, tensor_name, expected_value: bool): |
| 304 | + if expected_value: |
| 305 | + assert rht_amax_eqns[index].params["produce_regular_amax"] == True, ( |
| 306 | + f"Expected produce_regular_amax for {tensor_name} to be True, indicating no" |
| 307 | + " reuse of amax as this tensor does not have a previous operation to fuse" |
| 308 | + " with" |
| 309 | + ) |
| 310 | + else: |
| 311 | + assert rht_amax_eqns[index].params["produce_regular_amax"] == False, ( |
| 312 | + f"Expected produce_regular_amax for {tensor_name} to be False, indicating" |
| 313 | + " reuse of amax" |
| 314 | + ) |
| 315 | + |
| 316 | + assert_param(0, "fwd ln+q", False) |
| 317 | + assert_param(1, "fwd act+q", False) |
| 318 | + # No previous op before incoming dgrad in the backward so amax is not reused |
| 319 | + assert_param(2, "bwd dgrad", True) |
| 320 | + assert_param(3, "bwd dact+q", False) |
0 commit comments