diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 0f3d2cbbf0..7ddbb18077 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -22,7 +22,7 @@ ) from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors @@ -56,39 +56,36 @@ def get_nvfp4_quantizer_factory(): enabled. Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( + if role.tensor_type == "input": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for input + with_rht=True, ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "weight": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), # 2D quantization for weight + quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used + elif role.tensor_type == "output": return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "grad_output": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for grad_output + with_rht=True, ) - elif role == 
"linear_grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..8afb103056 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,7 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 5f35e9ad10..2823172d35 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -13,7 +13,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a96fea3af0..d727433a28 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -6,7 +6,7 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe -from 
transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils @@ -76,39 +76,36 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( + if role.tensor_type == "input": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "weight": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used + elif role.tensor_type == "output": return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + elif role.tensor_type == "grad_output": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..99f0f5cdd6 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -7,7 +7,7 @@ import 
transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f54..a9178b25d8 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -12,7 +12,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 536d43adc0..8231d013d3 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -17,8 +17,16 @@ GroupedLinear, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.quantization import QuantizerRole import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( +from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + current_scaling_quantizer_factory, + mxfp8_quantizer_factory, + float8_block_scaling_quantizer_factory, + nvfp4_quantizer_factory, + 
delayed_scaling_quantizer_factory, +) +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import ( nvfp4_ref_rht_2d_quantizer_factory, ) @@ -91,9 +99,9 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -128,9 +136,9 @@ def test_custom_recipe_grouped_linear_sanity(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -190,9 +198,9 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -210,7 +218,7 @@ def quantizer_factory(role): assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert 
cus_fwd_out.dtype == tex.DType.kFloat8E4M3 assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E4M3 # role=None fallback loss_custom = (out_custom.float() * scale.view(1, -1)).sum() loss_custom.backward() @@ -247,9 +255,9 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -277,39 +285,40 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): op = Linear(in_features, out_features, params_dtype=torch.bfloat16) inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) - # Counters per role + # Counters per tensor_type. The output (fwd) and grad_input (bwd) + # slots have role=None by default (unknown consumer), so we count + # those separately. 
counts = { - "linear_input": 0, - "linear_weight": 0, - "linear_output": 0, - "linear_grad_output": 0, - "linear_grad_input": 0, + "input": 0, + "weight": 0, + "grad_output": 0, + None: 0, } def quantizer_factory(role): - if role in counts: - counts[role] += 1 - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: + counts[None] += 1 return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if role in ("linear_grad_output", "linear_grad_input"): + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + assert role.module_type == "linear" + if role.tensor_type in counts: + counts[role.tensor_type] += 1 + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) custom = recipe.CustomRecipe(qfactory=quantizer_factory) - # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), - # and backward to build 2 quantizers (cycled from 1 factory). 
with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() - # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["linear_input"] == 1 - assert counts["linear_weight"] == 1 - assert counts["linear_output"] == 1 - assert counts["linear_grad_output"] == 1 - assert counts["linear_grad_input"] == 1 + # Forward: input, weight, output(None); backward: grad_output, grad_input(None) + assert counts["input"] == 1 + assert counts["weight"] == 1 + assert counts["grad_output"] == 1 + assert counts[None] == 2, f"Expected 2 None roles (output + grad_input), got {counts[None]}" def test_factories_return_distinct_instances_and_buffers(): @@ -331,3 +340,616 @@ def factory(): # Mutating one should not affect the other q1.scale.fill_(123.0) assert not torch.equal(q1.scale, q2.scale) + + +def _run_linear_fwd_bwd(model, inp, recipe): + """Run forward + backward with a given recipe and return (output, inp.grad, param grads).""" + with autocast(enabled=True, recipe=recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + param_grads = {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None} + return out.clone(), inp.grad.clone(), param_grads + + +def _make_pair(in_features=128, out_features=128, batch=32, seed=42): + """Create a pair of identical Linear models and matching inputs.""" + torch.manual_seed(seed) + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus.load_state_dict(model_ref.state_dict()) + + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_cus = base_inp.clone().detach().requires_grad_(True) + return model_ref, model_cus, inp_ref, inp_cus + + +def _assert_match(out_ref, out_cus, grad_ref, grad_cus, 
pgrads_ref, pgrads_cus): + """Assert exact match of outputs and all gradients.""" + assert torch.allclose( + out_ref, out_cus, rtol=0.0, atol=0.0 + ), f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" + assert torch.allclose( + grad_ref, grad_cus, rtol=0.0, atol=0.0 + ), f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" + for name in pgrads_ref: + assert torch.allclose(pgrads_ref[name], pgrads_cus[name], rtol=0.0, atol=0.0), ( + f"Param grad '{name}' mismatch: max diff = " + f"{(pgrads_ref[name] - pgrads_cus[name]).abs().max()}" + ) + + +def test_factory_matches_delayed_scaling(): + """delayed_scaling_quantizer_factory should produce bit-identical results + to the built-in DelayedScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd(model_ref, inp_ref, recipe.DelayedScaling()) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=delayed_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_current_scaling(): + """current_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8CurrentScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8CurrentScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=current_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, 
pgrads_ref, pgrads_cus) + + +def test_factory_matches_mxfp8(): + """mxfp8_quantizer_factory should produce bit-identical results + to the built-in MXFP8BlockScaling recipe.""" + available, reason = te.is_mxfp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"MXFP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.MXFP8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=mxfp8_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_block_scaling(): + """float8_block_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8BlockScaling recipe.""" + available = te.is_fp8_block_scaling_available() + if not torch.cuda.is_available() or not available: + pytest.skip("Float8 block scaling unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=float8_block_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_nvfp4(): + """nvfp4_quantizer_factory should produce bit-identical results + to the built-in NVFP4BlockScaling recipe.""" + available = te.is_nvfp4_available() + if not torch.cuda.is_available() or not available: + pytest.skip("NVFP4 unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.NVFP4BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, 
recipe.CustomRecipe(qfactory=nvfp4_quantizer_factory) + ) + + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_custom_recipe_quantization_targets(): + """Validate fine-grained per-module quantization targeting via QuantizerRole. + + Four transformer layers, each assembled at a different abstraction level. + The default recipe is NVFP4; specific modules are overridden: + + Layer 0 - ``TransformerLayer`` (name="tl0") -> all MXFP8 + Layer 1 - ``TransformerLayer`` (name="tl1") -> NVFP4 (default), + except fc2 overridden to MXFP8 + Layer 2 - ``MultiheadAttention`` + ``LayerNormMLP`` + (name prefix "tl2") -> NVFP4 (default), + except qkv and fc1 overridden to Float8 block-scaling + Layer 3 - Individual blocks (name prefix "tl3") -> NVFP4 (default), + except proj overridden to Float8 current-scaling + + The test validates that: + * The factory receives QuantizerRole objects with correct names + * Different quantizer types are dispatched per module + * Forward + backward complete successfully through all four layers + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_mxfp8_available(): + pytest.skip("MXFP8 unsupported on this device") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + if not te.is_fp8_block_scaling_available(): + pytest.skip("Float8 block scaling unsupported on this device") + + torch.manual_seed(42) + + H = 64 # hidden_size + FFN = 64 # ffn_hidden_size + NH = 4 # num_heads + KV = H // NH # kv_channels + B = 4 # batch + S = 8 # seq_len + common = dict(params_dtype=torch.bfloat16, bias=False) + + # Layer 0: TransformerLayer -> MXFP8 + tl0 = te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl0", + **common, + ).cuda() + + # Layer 1: TransformerLayer -> NVFP4 default, fc2 overridden to MXFP8 + tl1 = 
te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl1", + **common, + ).cuda() + + # Layer 2: MHA + LayerNormMLP -> NVFP4 default, qkv and fc1 to block-scaling + tl2_mha = te.MultiheadAttention( + H, + NH, + KV, + attention_dropout=0.0, + input_layernorm=True, + return_bias=True, + name="tl2.self_attention", + **common, + ).cuda() + tl2_mlp = LayerNormMLP(H, FFN, name="tl2.layernorm_mlp", **common).cuda() + + # Layer 3: Individual blocks with DPA -> NVFP4 default, proj to current-scaling + tl3_qkv = LayerNormLinear(H, 3 * H, name="tl3.qkv", **common).cuda() + tl3_dpa = te.DotProductAttention(NH, KV, attention_dropout=0.0, name="tl3.core_attention") + tl3_proj = Linear(H, H, name="tl3.proj", **common).cuda() + tl3_fc1 = LayerNormLinear(H, FFN, name="tl3.fc1", **common).cuda() + tl3_fc2 = Linear(FFN, H, name="tl3.fc2", **common).cuda() + + # ------------------------------------------------------------------ + # Recording + dispatching factory + # ------------------------------------------------------------------ + recorded_roles = [] + + def targeting_factory(role): + recorded_roles.append(role) + + if role is None: + return nvfp4_quantizer_factory(role) + + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + + # Layer 0 (tl0.*): all MXFP8 + if role.name.startswith("tl0"): + return mxfp8_quantizer_factory(role) + + # Layer 1 (tl1.*): NVFP4 default, but fc2 overridden to MXFP8 + if role.name == "tl1.layernorm_mlp.fc2": + return mxfp8_quantizer_factory(role) + + # Layer 2: block scaling for qkv and fc1, rest falls through to default + if role.name == "tl2.self_attention.layernorm_linear_qkv": + return float8_block_scaling_quantizer_factory(role) + if role.name == "tl2.layernorm_mlp.fc1": + return float8_block_scaling_quantizer_factory(role) + + # Layer 3: current-scaling for proj, rest falls through to default + if role.name == "tl3.proj": + return current_scaling_quantizer_factory(role) + + 
# Default: NVFP4 + return nvfp4_quantizer_factory(role) + + custom_recipe = recipe.CustomRecipe(qfactory=targeting_factory) + + # ------------------------------------------------------------------ + # Forward + backward + # ------------------------------------------------------------------ + inp = torch.randn(S, B, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + # Layer 0 & 1: TransformerLayer + h = tl1(tl0(inp)) + + # Layer 2: MHA + residual + LayerNormMLP + residual + attn_out, _ = tl2_mha(h) + h = h + attn_out + h = h + tl2_mlp(h) + + # Layer 3: individual blocks with DPA + residual = h + qkv = tl3_qkv(h).view(S, B, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn = tl3_dpa(q, k, v).view(S, B, H) + h = residual + tl3_proj(attn) + residual = h + h = residual + tl3_fc2(torch.nn.functional.gelu(tl3_fc1(h))) + + loss = h.float().sum() + loss.backward() + + # ------------------------------------------------------------------ + # Assertions + # ------------------------------------------------------------------ + + assert inp.grad is not None, "Input gradient is None" + + # -- Name propagation check -- + # The factory dispatches on role.name, so if a TE module fails to propagate + # names (e.g. TransformerLayer -> MHA -> LayerNormLinear) the factory would + # silently fall through to the default recipe. The quantizer-type assertions + # below would catch that too, but checking names explicitly gives a clearer + # error message pointing at the broken name rather than a wrong quantizer type. 
+ role_names = {r.name for r in recorded_roles if r is not None} + + def _tl_names(prefix): + """Expected role names for a standard TransformerLayer with given prefix.""" + return { + f"{prefix}.self_attention.layernorm_linear_qkv", + f"{prefix}.self_attention.proj", + f"{prefix}.layernorm_mlp.fc1", + f"{prefix}.layernorm_mlp.fc2", + } + + all_expected = ( + _tl_names("tl0") + | _tl_names("tl1") + | _tl_names("tl2") + | {"tl3.qkv", "tl3.proj", "tl3.fc1", "tl3.fc2"} + ) + missing = all_expected - role_names + assert not missing, ( + f"Expected module names not seen in QuantizerRole.name: {missing}\n" + f"Recorded names: {sorted(role_names)}" + ) + + for r in recorded_roles: + if r is not None and r.module_type: + assert r.module_type in ( + "linear", + "dpa", + ), f"Unexpected module_type={r.module_type} for role {r}" + + # -- Quantizer-type checks -- + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer + + def _check_q(mod, expected_cls, label=""): + q = mod.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + assert isinstance(q, expected_cls), ( + f"{mod.name}{' (' + label + ')' if label else ''}: " + f"expected {expected_cls.__name__}, got {type(q).__name__}" + ) + + # Layer 0: all MXFP8 + _check_q(tl0.self_attention.layernorm_qkv, MXFP8Quantizer) + _check_q(tl0.self_attention.proj, MXFP8Quantizer) + + # Layer 1: NVFP4 default, fc2 overridden to MXFP8 + _check_q(tl1.self_attention.layernorm_qkv, NVFP4Quantizer, "default") + _check_q(tl1.self_attention.proj, NVFP4Quantizer, "default") + assert any( + r is not None and r.name == "tl1.layernorm_mlp.fc2" and r.tensor_type == "input" + for r in recorded_roles + ), "tl1.layernorm_mlp.fc2 input role not recorded" + + # Layer 2: block-scaling on qkv and fc1, NVFP4 on proj and fc2 + _check_q(tl2_mha.layernorm_qkv, 
Float8BlockQuantizer) + _check_q(tl2_mha.proj, NVFP4Quantizer, "default") + + # Layer 3: current-scaling on proj, NVFP4 on everything else + _check_q(tl3_proj, Float8CurrentScalingQuantizer) + for mod in [tl3_qkv, tl3_fc1, tl3_fc2]: + _check_q(mod, NVFP4Quantizer, "default") + + +def test_delayed_scaling_request_wiring(): + """Shared buffers, correct views, Float8Quantizer instances.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID, amax_history_len=16) + + custom_recipe = recipe.CustomRecipe(qfactory=ds_factory) + + # 3 quantizers (input, weight, output) like a Linear fwd + state = CustomRecipeState(custom_recipe, mode="forward", num_quantizers=3) + state.roles = [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + quantizers = state.make_quantizers() + + # All quantizers should be Float8Quantizer + assert len(quantizers) == 3 + for q in quantizers: + assert isinstance(q, Float8Quantizer), f"Expected Float8Quantizer, got {type(q).__name__}" + + # Managed state should exist + assert state._has_delayed_scaling + assert state.scale is not None + assert state.amax_history is not None + + # Shared buffers: scale shape = (3,), amax_history shape = (16, 3) + assert state.scale.shape == (3,) + assert state.amax_history.shape == (16, 3) + + # Each quantizer's scale should be a view into the shared buffer + for i, q in enumerate(quantizers): + assert q.scale.data_ptr() == state.scale[i].data_ptr() + + # Each 
quantizer's amax should be a view into amax_history[0] + for i, q in enumerate(quantizers): + assert q.amax.data_ptr() == state.amax_history[0][i].reshape((1,)).data_ptr() + + # Inner recipe should be a DelayedScaling + inner = state._inner_delayed_scaling_recipe + assert isinstance(inner, recipe.DelayedScaling) + assert inner.amax_history_len == 16 + assert inner.fp8_format == Format.HYBRID + + +def test_custom_recipe_mixed_ds_and_stateless(): + """Mix DelayedScalingRequest + stateless quantizers in same CustomRecipeState.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def mixed_factory(role): + # Only weight gets delayed scaling, rest get current scaling + if role is not None and role.tensor_type == "weight": + return DelayedScalingRequest(fp8_format=Format.HYBRID) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=mixed_factory) + + # 3 quantizers: input(current), weight(DS), output(current) + state = CustomRecipeState(custom_recipe, mode="forward", num_quantizers=3) + state.roles = [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + + # Slot 0 (input): current scaling + assert isinstance(quantizers[0], Float8CurrentScalingQuantizer) + # Slot 1 (weight): delayed scaling + assert isinstance(quantizers[1], Float8Quantizer) + # Slot 2 (output): current scaling + assert isinstance(quantizers[2], Float8CurrentScalingQuantizer) + + # Only 
1 DS request => shared buffers have size 1 + assert state._has_delayed_scaling + assert state.scale.shape == (1,) + assert state.amax_history.shape == (1024, 1) + + +def test_custom_recipe_ds_multi_step(): + """amax_history updates across multiple forward steps.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + in_features = 128 + out_features = 128 + batch = 32 + num_steps = 3 + + torch.manual_seed(99) + model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + custom = recipe.CustomRecipe(qfactory=ds_factory) + + amax_snapshots = [] + for step in range(num_steps): + inp = torch.randn( + batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + with autocast(enabled=True, recipe=custom): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Capture amax_history snapshot + fwd_state = model.fp8_meta["scaling_fwd"] + amax_snapshots.append(fwd_state.amax_history.clone()) + + # After 3 steps, amax_history should have been updated at least once + # The first row (amax_history[0]) should differ from the initial zeros + # after the first step + assert not torch.all(amax_snapshots[0] == 0), "amax_history should be updated after first step" + + +def test_custom_recipe_dpa_fp8(): + """DotProductAttention forward+backward with CustomRecipe and role-based mixed quantizers. 
+ + Uses the nvfp4_linear_fp8_dpa_factory which dispatches: + * DPA S/dP slots -> DelayedScalingRequest (stateful) + * DPA QKV/O/dO/dQKV slots -> Float8CurrentScalingQuantizer + * Linear slots -> NVFP4Quantizer + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.utils import get_device_compute_capability + + cc = get_device_compute_capability() + if cc < (9, 0) or cc >= (12, 0): + pytest.skip(f"FP8 attention not supported on sm{cc[0]*10+cc[1]}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, + ) + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + torch.manual_seed(42) + + H = 64 + NH = 4 + KV = H // NH + B = 2 + S = 32 + + # Build a small model: Linear -> DPA -> Linear + qkv_proj = Linear(H, 3 * H, params_dtype=torch.bfloat16, bias=False, name="qkv").cuda() + dpa = te.DotProductAttention( + NH, KV, attention_dropout=0.0, qkv_format="bshd", name="core_attention" + ) + out_proj = Linear(H, H, params_dtype=torch.bfloat16, bias=False, name="proj").cuda() + + custom_recipe = recipe.CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + + inp = torch.randn(B, S, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + qkv = qkv_proj(inp).view(B, S, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn_out = dpa(q, k, v, qkv_format="bshd").reshape(B, S, H) + out = out_proj(attn_out) + + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None, "Input gradient should 
exist" + + # Verify DPA recipe state is CustomRecipeState + fwd_state = dpa.fp8_meta["scaling_fwd"] + assert isinstance( + fwd_state, CustomRecipeState + ), f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" + + # Verify DPA quantizers: 9 forward slots (3 GEMMs x 3) + fwd_quantizers = dpa.quantizers["scaling_fwd"] + assert len(fwd_quantizers) == 9, f"Expected 9 fwd quantizers, got {len(fwd_quantizers)}" + + # Slots 0-2: QKV (GEMM1) -> current scaling (role: module_type="dpa") + # Slots 3-5: O (GEMM2) -> current scaling (role: name hint "dpa_output") + # Slots 6-8: S (GEMM3) -> delayed scaling (Float8Quantizer from DelayedScalingRequest) + for i in range(6): + assert isinstance(fwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Slot {i} (QKV/O): expected Float8CurrentScalingQuantizer, " + f"got {type(fwd_quantizers[i]).__name__}" + ) + for i in range(6, 9): + assert isinstance(fwd_quantizers[i], Float8Quantizer), ( + f"Slot {i} (S): expected Float8Quantizer (delayed scaling), " + f"got {type(fwd_quantizers[i]).__name__}" + ) + + # Verify DS state exists for the S/dP delayed scaling requests + assert fwd_state._has_delayed_scaling, "DPA fwd state should have delayed scaling for S slots" + + # Verify backward quantizers exist too + bwd_quantizers = dpa.quantizers["scaling_bwd"] + assert len(bwd_quantizers) == 6, f"Expected 6 bwd quantizers, got {len(bwd_quantizers)}" + + # Slots 0-1: dQKV (GEMM1) -> current scaling (role: name hint "dpa_grad_input") + # Slots 2-3: dO (GEMM2) -> current scaling (role: module_type="dpa") + # Slots 4-5: dP (GEMM3) -> delayed scaling + for i in range(4): + assert isinstance(bwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Bwd slot {i} (dQKV/dO): expected Float8CurrentScalingQuantizer, " + f"got {type(bwd_quantizers[i]).__name__}" + ) + for i in range(4, 6): + assert isinstance(bwd_quantizers[i], Float8Quantizer), ( + f"Bwd slot {i} (dP): expected Float8Quantizer (delayed scaling), " + f"got 
{type(bwd_quantizers[i]).__name__}" + ) + + # Linear modules should have CustomRecipeState with NVFP4 quantizers + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + qkv_fwd = qkv_proj.fp8_meta["scaling_fwd"] + assert isinstance( + qkv_fwd, CustomRecipeState + ), f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" + qkv_fwd_quantizers = qkv_proj.quantizers["scaling_fwd"] + for i, q in enumerate(qkv_fwd_quantizers): + if q is not None: + assert isinstance( + q, NVFP4Quantizer + ), f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 99ab9c4984..3b964a5af9 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.custom_recipes.quantization import MMParams -from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import ( +from transformer_engine.pytorch.custom_recipes.quantization_ref_current_scaling import ( CurrentScalingQuantizerRef, ) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 18577b0eb4..c6cae2efcb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -499,23 +499,41 @@ class CustomRecipe(Recipe): Parameters ---------- qfactory : Callable - Factory callable that returns a quantizer instance for a - given semantic tensor role. - The callable is typically invoked as:: + Factory callable that returns a quantizer instance *or* a + ``QuantizerRequest`` subclass for a given ``QuantizerRole``. 
+ The callable is invoked as:: qfactory( - role: str, - ) - - Where `role` is one of the following strings for e.g. te.Linear - (stable public contract): - - - forward: "linear_input", "linear_weight", "linear_output" - - backward: "linear_grad_output", "linear_grad_input" + role: QuantizerRole, + ) -> Union[Quantizer, QuantizerRequest] + + ``QuantizerRole`` is a frozen dataclass with the following fields: + + - ``module_type`` (str): module type (empty string when not set), e.g. + ``"linear"``, ``"grouped_linear"``, ``"dpa"``. + - ``tensor_type`` (str): what tensor is being quantized (empty + string when not set), e.g. ``"input"``, ``"weight"``, ``"grad_output"``. + - ``name`` (str): caller-provided module instance name (empty + string when not set), e.g. ``"qkv"``, ``"proj"``, ``"fc1"``, ``"fc2"``. + + For stateful quantizers (delayed scaling), return a + ``DelayedScalingRequest`` dataclass instead of a quantizer. + TE will allocate shared scale/amax_history buffers and create + ``Float8Quantizer`` instances integrated with the existing + delayed-scaling reduction infrastructure. + + See ``transformer_engine.pytorch.quantization.QuantizerRole`` + and ``transformer_engine.pytorch.quantization.DelayedScalingRequest`` + for full documentation. """ qfactory: Callable[..., Any] + # fp8_format does not affect quantization (quantization factory controls that), + # but TE internals (e.g. get_fp8_te_dtype, backend selection) read it + # from the recipe. HYBRID (E4M3 fwd, E5M2 bwd) is a safe default. 
+ fp8_format: Format = Format.HYBRID + fp8_dpa: bool = False fp8_mha: bool = False diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..44907dd658 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -48,6 +48,9 @@ from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.quantization import QuantizerRole +from transformer_engine.pytorch.quantization import QuantizerRequest +from transformer_engine.pytorch.quantization import DelayedScalingRequest from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import is_bf16_available diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..091621fa76 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -22,6 +22,7 @@ ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( + QuantizerRole, get_fp8_te_dtype, FP8GlobalStateManager, RecipeState, @@ -284,6 +285,8 @@ class DotProductAttention(TransformerEngineBaseModule): `_). :math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``, and :math:`\text{max_logit}` is of shape ``[h]``. + name : Optional[str], default = None + module instance name. 
Parallelism parameters ---------------------- @@ -343,8 +346,9 @@ def __init__( softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name=name) self.logger = logging.getLogger("DotProductAttention") self.logger.setLevel(attn_log._log_level) @@ -584,6 +588,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_recipe.custom(): + super().init_fp8_metadata(num_gemms=num_gemms) return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to @@ -756,6 +761,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe) and recipe.custom(): + return TransformerEngineBaseModule.set_meta_tensor(self, fwd, recipe) if isinstance(recipe, Recipe): recipe = [recipe] fp8_recipe_dpa = recipe[-1] @@ -802,6 +809,96 @@ def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> Non for recipe_state in recipe_states: self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``DotProductAttention``. + + DPA internally performs two matmuls:: + + S = softmax(Q · K^T) (GEMM1) + O = S · V (GEMM2) + + cuDNN's fused-attention API exposes FP8 scale/amax descriptors as + a flat array of **slot groups** numbered 1-3. 
The numbering is a + cuDNN convention — it does *not* correspond to operation order + inside DPA: + + Forward (3 slot groups × 3 positions = 9 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 QKV — inputs to GEMM1 (Q·K^T) GEMM1_OUTPUT + Group 2 O — output of GEMM2 (S·V) GEMM2_INPUT + Group 3 S — post-softmax, input to GEMM2 (S·V) GEMM3_OUTPUT + =========== =========================================== =========== + + Backward (3 slot groups × 2 positions = 6 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 dQKV — gradients flowing back to Q, K, V GRAD_OUTPUT1 + Group 2 dO — gradient of the attention output GRAD_INPUT2 + Group 3 dP — gradient of the softmax output GRAD_INPUT3 + =========== =========================================== =========== + + Unused positions within a group share the role of the group's + primary tensor. + + **Boundary slots** — O (fwd) and dQKV (bwd) leave DPA and enter + the next module (e.g. proj linear). DPA does not know that + consumer, so these default to ``None``. The parent module + (e.g. ``MultiheadAttention``) can set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to fill in the consumer identity. + + When not set, a hint-only ``QuantizerRole`` with empty + ``module_type`` / ``tensor_type`` is emitted, with ``name`` + containing ``"dpa_output"`` or ``"dpa_grad_input"``. This lets + the factory return a DPA-compatible quantizer (required by the + fused kernel) even when the downstream consumer is unknown. 
+ """ + name = self.name or "" + if fwd: + qkv_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name=name) + o_role = self._output_quantizer_role + if o_role is None: + o_role = QuantizerRole(name=f"{name}.dpa_output" if name else "dpa_output") + s_role = QuantizerRole(module_type="dpa", tensor_type="s", name=name) + base = [ + qkv_role, + qkv_role, + qkv_role, # Group 1: QKV (inputs to Q·K^T) + o_role, + o_role, + o_role, # Group 2: O (output of S·V) — boundary + s_role, + s_role, + s_role, # Group 3: S (post-softmax, input to S·V) + ] + else: + dqkv_role = self._grad_input_quantizer_role + if dqkv_role is None: + dqkv_role = QuantizerRole( + name=f"{name}.dpa_grad_input" if name else "dpa_grad_input" + ) + do_role = QuantizerRole(module_type="dpa", tensor_type="do", name=name) + dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) + base = [ + dqkv_role, + dqkv_role, # Group 1: dQKV (grads to Q,K,V) — boundary + do_role, + do_role, # Group 2: dO (grad of attention output) + dp_role, + dp_role, # Group 3: dP (grad of softmax output) + ] + return base[:num_quantizers] + @no_torch_dynamo(recursive=False) def forward( self, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..8ad2376e3e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2118,6 +2118,24 @@ def get_attention_quantizers(fp8, quantizers): dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + _fp8_types = (Float8Quantizer, Float8CurrentScalingQuantizer) + for _name, _q in [ + ("QKV", QKV_quantizer), + ("O", O_quantizer), + ("S", S_quantizer), + ("dQKV", dQKV_quantizer), + ("dO", dO_quantizer), + ("dP", dP_quantizer), + ]: + assert isinstance(_q, _fp8_types), ( + "FP8 attention requires FP8-compatible quantizers for all DPA 
tensor slots, " + f"but {_name} quantizer is {type(_q).__name__}. " + "When using CustomRecipe with fp8_dpa=True, ensure the factory returns an " + "FP8 quantizer (Float8Quantizer or Float8CurrentScalingQuantizer) for all " + "DPA roles (module_type='dpa') and for None roles (boundary slots like " + "O output and dQKV grad-input)." + ) + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..1607d513b1 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, QuantizerRole from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm @@ -461,6 +461,7 @@ def __init__( layer_number=self.layer_number, attention_type=self.attention_type, softmax_type=self.softmax_type, + name=name + ".core_attention" if name is not None else None, ) # Linear @@ -478,6 +479,84 @@ def __init__( **common_gemm_kwargs, ) + def _update_output_quantizer_roles( + self, + qkv_fp8_output: bool, + proj_fp8_grad: bool, + dpa_fp8_output: bool, + ) -> None: + """Set quantizer roles at the boundaries between QKV, DPA, and proj. 
+ + MHA contains three submodules connected as follows:: + + Forward: QKV linear ──(QKV tensor)──> DPA ──(O tensor)──> Proj linear + Backward: QKV linear <──(dQKV tensor)── DPA <──(dO tensor)── Proj linear + + Each submodule owns quantizers for its internal tensors, but the + *boundary* tensors (the arrows above) need to know which module + will *consume* them so the quantizer factory can pick the right + format. This method sets those boundary roles on all four edges: + + 1. ``qkv_fp8_output`` — **QKV linear → DPA (fwd)**: the QKV + linear's ``output_quantizer_role`` is told its consumer is DPA. + 2. ``proj_fp8_grad`` — **Proj linear ← DPA (bwd)**: proj's + ``grad_input_quantizer_role`` is told its producer is DPA. + 3. ``dpa_fp8_output`` — **DPA → Proj linear (fwd)**: DPA's + ``output_quantizer_role`` is told its consumer is the proj linear. + 4. ``dpa_fp8_output`` — **DPA ← QKV linear (bwd)**: DPA's + ``grad_input_quantizer_role`` is told its consumer is QKV linear. + + When a flag is ``False`` the corresponding role is reset to ``None`` + so the module falls back to its own default. 
+ """ + dpa_name = self.core_attention.name or "" + + # ── Boundary 1 (fwd): QKV linear output → consumed by DPA ──────── + qkv_output_role = ( + QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) + if qkv_fp8_output + else None + ) + if self.attention_type == "self": + if self.input_layernorm: + self.layernorm_qkv.output_quantizer_role = qkv_output_role + else: + self.qkv.output_quantizer_role = qkv_output_role + elif self.attention_type == "cross": + if self.input_layernorm: + self.layernorm_query.output_quantizer_role = qkv_output_role + else: + self.query_layer.output_quantizer_role = qkv_output_role + self.key_value.output_quantizer_role = qkv_output_role + + # ── Boundary 2 (bwd): Proj grad-input ← produced by DPA ────────── + proj_grad_input_role = ( + QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) + if proj_fp8_grad + else None + ) + self.proj.grad_input_quantizer_role = proj_grad_input_role + + # ── Boundary 3 (fwd): DPA output (O) → consumed by Proj linear ─── + proj_name = self.proj.name or "" + self.core_attention.output_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="input", name=proj_name) + if dpa_fp8_output + else None + ) + + # ── Boundary 4 (bwd): DPA grad-input (dQKV) → consumed by QKV linear + if self.attention_type == "self": + qkv_linear = self.layernorm_qkv if self.input_layernorm else self.qkv + else: + qkv_linear = self.layernorm_query if self.input_layernorm else self.query_layer + qkv_name = qkv_linear.name or "" + self.core_attention.grad_input_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="grad_output", name=qkv_name) + if dpa_fp8_output + else None + ) + def fast_setattr(self, name: str, value: Any) -> None: """Fast attribute set for non-parameter fields.""" self.__dict__[name] = value @@ -806,6 +885,8 @@ def forward( # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling + 
self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad, dpa_fp8_output) + layernorm_output = None if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py new file mode 100644 index 0000000000..6f5fead4e3 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples. + +Demonstrates how to use the ``CustomRecipe`` + ``qfactory`` interface to apply +*different* quantization recipes to different module/tensor types/instances within the same model. + +Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_grouped_linear_factory, + nvfp4_linear_fp8_dpa_factory, + ) + + # Mixed module types: NVFP4 for Linear, MXFP8 for GroupedLinear + recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_grouped_linear_factory) + with autocast(recipe=recipe): + output = model(input) + + # NVFP4 for Linear, FP8 current-scaling + delayed-scaling for DPA + recipe = CustomRecipe(qfactory=nvfp4_linear_fp8_dpa_factory, fp8_dpa=True) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def nvfp4_linear_mxfp8_grouped_linear_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, MXFP8 for ``GroupedLinear``. 
+ + Dispatch logic: + * ``role.module_type == "grouped_linear"`` -> MXFP8 (E4M3, block-32) + * everything else (``"linear"`` or unknown) -> NVFP4 (E2M1) + + NVFP4 settings follow the built-in ``NVFP4BlockScaling`` defaults: + * Weights: 2D quantization (16x16), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + """ + is_grouped_linear = role is not None and role.module_type == "grouped_linear" + + if is_grouped_linear: + return _make_mxfp8_quantizer() + + return _make_nvfp4_quantizer(role) + + +def _make_mxfp8_quantizer(): + """Return an MXFP8 quantizer with default settings (E4M3, block-32, E8M0 scales).""" + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): + """Return an NVFP4 quantizer configured per tensor role. + + Mirrors :class:`NVFP4BlockScaling` recipe defaults. 
+ """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type == "linear" + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + +def nvfp4_linear_fp8_dpa_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, mixed FP8 for ``DotProductAttention``. + + This factory demonstrates how to use ``CustomRecipe`` with ``fp8_dpa=True`` + to combine NVFP4 quantization for linear layers with FP8 attention. 
+ + DPA tensor types (``role.module_type == "dpa"``): + + =========== ============================================================ + tensor_type Description + =========== ============================================================ + ``"qkv"`` Query, Key, Value inputs to the first attention GEMM + ``"s"`` Softmax output (S = softmax(Q·K^T)), fed into the second GEMM + ``"o"`` Attention output (O = S·V) + ``"do"`` Gradient of the attention output (dO), backward input + ``"dp"`` Gradient of the softmax output (dP = dO·V^T), backward + ``"dqkv"`` Gradient flowing back to Q, K, V + =========== ============================================================ + + Dispatch logic: + * ``role.module_type == "dpa"`` with ``tensor_type in ("s", "dp")`` + -> FP8 delayed scaling (stateful amax tracking) + * ``role.module_type == "dpa"`` (QKV, dO) + -> FP8 current scaling (E4M3) + * DPA boundary hints (``"dpa_output"`` / ``"dpa_grad_input"`` in ``role.name``) + -> FP8 current scaling placeholder. The fused attention kernel requires + FP8-compatible quantizers in all DPA slots, even when the output is + produced in BF16 (``fp8_mha=False``). DPA emits these hint-only roles + (with empty ``module_type`` and ``tensor_type``) when the downstream + consumer is unknown. 
+ * everything else (``"linear"`` / ``"grouped_linear"`` / ``None``) + -> NVFP4 (E2M1), configured per tensor role + + Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + recipe = CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + with autocast(recipe=recipe): + output = model(input) + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer + + is_dpa = role is not None and role.module_type == "dpa" + is_softmax_or_dp = is_dpa and role.tensor_type in ("s", "dp") + + if is_softmax_or_dp: + return DelayedScalingRequest() + + if is_dpa: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + # DPA boundary slots (O output / dQKV grad-input): the fused attention + # kernel only supports FP8 quantizers here, regardless of the linear recipe. + is_dpa_boundary = ( + role is not None + and not role.module_type + and ("dpa_output" in role.name or "dpa_grad_input" in role.name) + ) + if is_dpa_boundary: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + return _make_nvfp4_quantizer(role) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py new file mode 100644 index 0000000000..c79f0a4b9b --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples using real silicon quantizers. 
+ +Each factory below replicates the behaviour of built-in TE recipe but via the +``CustomRecipe`` + ``qfactory`` interface. This is useful when you want to +start from a known-good recipe and then selectively override quantizer settings +for specific layers / tensor types. + +Usage (any factory):: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + nvfp4_quantizer_factory, + ) + + recipe = CustomRecipe(qfactory=nvfp4_quantizer_factory) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def delayed_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "DelayedScalingRequest": + """Factory that mirrors :class:`DelayedScaling` recipe defaults. + + Returns a :class:`DelayedScalingRequest` for every slot. TE allocates + shared scale/amax_history buffers and wires them into the existing + delayed-scaling reduction path. + + * HYBRID format: E4M3 forward, E5M2 backward + * amax_history_len = 1024 + * reduce_amax = True + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + +def current_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8CurrentScalingQuantizer": + """Factory that mirrors :class:`Float8CurrentScaling` recipe defaults. 
+ + * Forward tensors (input, weight) → E4M3 + * Backward tensors (grad_output) → E5M2 + """ + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + is_backward = role is not None and role.tensor_type == "grad_output" + fp8_dtype = tex.DType.kFloat8E5M2 if is_backward else tex.DType.kFloat8E4M3 + + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + force_pow_2_scales=False, # constrain scale to powers of 2 + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + ) + + +def mxfp8_quantizer_factory( + role: Optional[QuantizerRole], +) -> "MXFP8Quantizer": + """Factory that mirrors :class:`MXFP8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Block size 32, power-of-2 (E8M0) scales + """ + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def float8_block_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8BlockQuantizer": + """Factory that mirrors :class:`Float8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Weights use 2D block scaling, everything else uses 1D + * Power-of-2 scales by default + """ + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + is_weight = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + block_scaling_dim = 2 if is_weight else 1 + + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + force_pow_2_scales=True, + block_scaling_dim=block_scaling_dim, # 1 = 1D (1×128), 2 = 2D (128×128) + ) + + +def nvfp4_quantizer_factory( + role: Optional[QuantizerRole], +) -> "NVFP4Quantizer": + """Factory that mirrors :class:`NVFP4BlockScaling` recipe defaults. 
+ + * All tensors quantized to E2M1 (FP4) + * Weights: 2D quantization (16x16 blocks), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + + Quantizer knobs: + fp4_dtype - E2M1 (only supported format) + with_rht - randomized Hadamard transform (smooths outliers) + with_post_rht_amax - recompute amax after RHT (should match with_rht) + with_2d_quantization - 16x16 2D blocks (vs 1x16 1D) + stochastic_rounding - probabilistic rounding to reduce quant bias + with_random_sign_mask - random sign flip in the Hadamard matrix + """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type in ("linear", "grouped_linear") + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + # For input and unknown roles + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py rename to 
transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py index 5bdc537e4b..0034b739cb 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py @@ -18,17 +18,18 @@ def current_scaling_ref_quantizer_factory(role): """Factory function for current scaling reference quantizer. - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. + + Backward tensors use E5M2, everything else uses E4M3. + + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory) with autocast(recipe=custom_recipe): output = model(input) """ - if role in ("linear_input", "linear_weight"): - dtype = torch.float8_e4m3fn - elif role in ("linear_output", "linear_grad_output"): - dtype = torch.float8_e5m2 - else: - return None + is_backward = role is not None and role.tensor_type == "grad_output" + dtype = torch.float8_e5m2 if is_backward else torch.float8_e4m3fn return CurrentScalingQuantizerRef( dtype=dtype, rowwise=True, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index d00d0c8b94..0b29977fb4 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -18,33 +18,32 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): """ Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. 
+ + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) - with autocast(fp8_recipe=custom_recipe): + with autocast(recipe=custom_recipe): output = model(input) """ - if role == "linear_input": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - if role == "linear_weight": + is_weight_tensor_in_gemm = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if role == "linear_grad_output": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - return None + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) def cast_to_fp4x2(x): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9c21141a39..5b00e58501 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -27,8 +27,11 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState, + CustomRecipeState, FP8GlobalStateManager, + QuantizerRole, RecipeState, + _has_delayed_scaling_state, ) from ..distributed import ( gather_along_first_dim, @@ -631,6 +634,8 @@ def __init__(self, name: Optional[str] = None) -> None: self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None + self._output_quantizer_role: Optional[QuantizerRole] = None + self._grad_input_quantizer_role: Optional[QuantizerRole] = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -651,6 
+656,72 @@ def module_setattr(self, name: str, value: Any) -> None: """ super().__setattr__(name, value) + @property + def output_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the forward output quantizer. + + When set, overrides the default role used by :meth:`get_quantizer_roles` + for the forward-pass output quantizer slot. Setting this after + quantizers have been created forces their recreation on the next + forward pass. + + See also :attr:`grad_input_quantizer_role` for the backward-pass + counterpart. + """ + return self._output_quantizer_role + + @output_quantizer_role.setter + def output_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._output_quantizer_role: + return + self._output_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + @property + def grad_input_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the grad-input quantizer. + + Backward-pass counterpart of :attr:`output_quantizer_role`. + """ + return self._grad_input_quantizer_role + + @grad_input_quantizer_role.setter + def grad_input_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._grad_input_quantizer_role: + return + self._grad_input_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + def _warn_missing_output_quantizer_role( + self, + fp8_output: bool, + fp8_grad: bool, + ) -> None: + """Warn when quantized output is requested but no consumer role is set. + + Only relevant for ``CustomRecipe`` where the ``qfactory`` dispatches + on roles. Built-in recipes ignore role metadata. 
+ """ + recipe = FP8GlobalStateManager.get_fp8_recipe() + if not recipe.custom(): + return + if fp8_output and self._output_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_output=True but " + "output_quantizer_role is not set. The CustomRecipe qfactory " + "will receive None for the output quantizer role.", + stacklevel=3, + ) + if fp8_grad and self._grad_input_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_grad=True but " + "grad_input_quantizer_role is not set. The CustomRecipe " + "qfactory will receive None for the grad-input quantizer role.", + stacklevel=3, + ) + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. @@ -726,21 +797,104 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return + if recipe.custom() and isinstance(recipe_state, CustomRecipeState): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 # Initialize recipe state and quantizers + roles = self.get_quantizer_roles(fwd=fwd, num_quantizers=num_fp8_tensors) + if roles is not None: + assert ( + len(roles) == num_fp8_tensors + ), f"Recipe roles must match number of quantizers ({len(roles)=} vs {num_fp8_tensors=})" recipe_state = RecipeState.create( recipe, mode=("forward" if fwd else "backward"), num_quantizers=num_fp8_tensors, + roles=roles, ) self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. 
+ + Overview + -------- + When using ``CustomRecipe``, the quantizer factory is called once + per quantizer slot. Each call receives a ``QuantizerRole`` that + tells the factory *what* is being quantized so it can return the + right quantizer. + + This method builds the role list. Subclasses override it to + describe their internal GEMM layout. + + Slot layout + ----------- + Return one ``QuantizerRole`` per slot, in the same order as the + module's quantizer array. For example, ``Linear`` uses 3 + forward slots ``[input, weight, output]`` and 2 backward slots + ``[grad_output, grad_input]``. Multi-GEMM modules like + ``LayerNormMLP`` repeat that pattern for each GEMM: + ``[fc1_input, fc1_weight, fc1_output, fc2_input, fc2_weight, fc2_output]``. + + What to put in each slot + ------------------------ + Create a ``QuantizerRole(module_type=..., tensor_type=..., + name=...)`` for each slot: + + * **module_type** — the kind of module: ``"linear"``, + ``"grouped_linear"``, ``"dpa"``, etc. The factory can dispatch + on this to use different quantization formats per module type. + * **tensor_type** — what tensor this slot holds, in the module's + own vocabulary. For linears: ``"input"``, ``"weight"``, + ``"grad_output"``, etc. For DPA: ``"qkv"``, ``"s"``, + ``"do"``, ``"dp"``, etc. + * **name** — the instance name (e.g. ``"encoder.layer0.fc1"``), + propagated from the ``name=`` constructor argument. The factory + can dispatch on this to target specific layers. + + Boundary slots + -------------- + The last slot of a forward GEMM group (output) and the last slot + of a backward group (grad_input) are **boundary** slots — the + tensor leaves this module and enters an unknown consumer. For + these slots, use ``self._output_quantizer_role`` (fwd) and + ``self._grad_input_quantizer_role`` (bwd), which default to + ``None``. A parent module (e.g. 
``MultiheadAttention``) can set + these attributes to fill in the consumer identity; see + ``MultiheadAttention._update_output_quantizer_roles`` for an + example. + + Return value + ------------ + * A list of ``QuantizerRole`` with length ``num_quantizers``. + * ``None`` to opt out of role-based dispatch. + + Not implemented (default) + ~~~~~~~~~~~~~~~~~~~~~~~~~ + The base implementation returns ``None``. When ``None`` is + returned, ``CustomRecipeState`` emits a warning and falls back + to bare ``QuantizerRole()`` instances (all fields empty) for + every slot. The factory still gets called once per slot, but + every call receives an identical empty role — it cannot + distinguish input from weight, forward from backward, or one + module from another. What happens then depends entirely on the + factory: it may return the same quantizer for all slots (acting + as a uniform recipe), or it may raise an error if it requires + meaningful roles to dispatch on. + """ + return None + def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() @@ -844,7 +998,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Copy tensors to CPU and store state = {} state["recipe"] = self.fp8_meta["recipe"] - if state["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) @@ -916,7 +1070,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load tensors - if self.fp8_meta["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) copy_tensor(state["scale_bwd"], 
self.fp8_meta["scaling_bwd"].scale) @@ -1042,7 +1196,7 @@ def prepare_forward( # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): - delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = _has_delayed_scaling_state(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: assert inp.is_cuda, "TransformerEngine needs CUDA." @@ -1054,10 +1208,12 @@ def prepare_forward( self.init_fp8_metadata(num_gemms=num_gemms) self._check_weight_tensor_recipe_correspondence() - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe: if self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( + assert self.fp8_meta["recipe"].reduce_amax or ( + self.fp8_meta["recipe"].custom() + ), ( "Amax reduction across tensor parallel group is " "necessary when using sequence parallelism with FP8." ) @@ -1079,7 +1235,7 @@ def end_forward(self): Required to be called at the end of the forward function to properly handle DelayedScaling metadata handling and the NVTX ranges. 
""" - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) nvtx_range_pop() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..09817e14be 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -22,7 +22,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, cast_if_needed, @@ -106,7 +106,18 @@ def forward( # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): - raise ValueError("DelayedScaling recipe is not supported with save_original_input") + if module.fp8_meta["recipe"].custom(): + # Custom recipe factory may produce DS quantizers unknown to caller. + # TODO(negvet): fix on Megatron side — guard should also exclude 'custom', or + # better: check at runtime whether quantizers are DS-based. + warnings.warn( + "save_original_input is incompatible with delayed-scaling quantizers " + "(Float8Quantizer). 
Disabling save_original_input for this module.", + stacklevel=2, + ) + save_original_input = False + else: + raise ValueError("DelayedScaling recipe is not supported with save_original_input") if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( @@ -749,6 +760,33 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``GroupedLinear``. + + For grouped GEMMs we repeat the same pattern for each GEMM in + order. The output (fwd) and grad-input (bwd) slots default to + ``None`` (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="input", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def make_grouped_weights(self, defer_init=False) -> None: """ Convert parameters into a GroupedTensor and re-register them as parameters. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce0581024a..78e05e94ca 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -26,7 +26,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, assert_dim_for_all_gather, @@ -1413,6 +1413,32 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormLinear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
+ """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1611,6 +1637,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 16e620fd94..071460f7b0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -27,7 +27,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -1980,6 +1980,44 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormMLP``. + + Each internal GEMM (fc1, fc2) gets a distinct name suffix so that + custom-recipe factories can target them individually. + + The module's final output (fc2 fwd) and final grad (fc1 bwd) + slots default to ``None`` (unknown consumer). Set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to provide consumer identity. 
Internal boundaries use fixed + roles with known consumer identity. + """ + base_name = self.name or "" + fc1_name = f"{base_name}.fc1" if base_name else "fc1" + fc2_name = f"{base_name}.fc2" if base_name else "fc2" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + self._grad_input_quantizer_role, + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -2187,6 +2225,9 @@ def forward( return out def _get_quantizers(self, fp8_output, is_grad_enabled): + if self.fp8: + self._warn_missing_output_quantizer_role(fp8_output, False) + ( fc1_input_quantizer, fc1_output_quantizer, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 31dac4d329..2254d8ff91 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, clear_tensor_data, @@ -176,10 +176,22 @@ def forward( if fp8: assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) - if 
save_original_input: - assert not isinstance( - input_quantizer, Float8Quantizer - ), "DelayedScaling recipe is not supported with save_original_input" + if save_original_input and isinstance(input_quantizer, Float8Quantizer): + if module.fp8_meta["recipe"].custom(): + # Custom recipe factory may produce DS quantizers unknown to caller. + # TODO(negvet): fix on Megatron side — guard in attention.py checks + # `fp8_recipe != 'delayed'` but should also exclude 'custom', or + # better: check at runtime whether quantizers are DS-based. + warnings.warn( + "save_original_input is incompatible with delayed-scaling quantizers " + "(Float8Quantizer). Disabling save_original_input for this module.", + stacklevel=2, + ) + save_original_input = False + else: + raise AssertionError( + "DelayedScaling recipe is not supported with save_original_input" + ) if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor @@ -1309,6 +1321,32 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``Linear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
+ """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1479,6 +1517,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 48376a297f..dfb1b7f741 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -270,6 +270,21 @@ def num_quantizers(self, mode: str) -> int: return 1 return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" + if mode == "forward": + # BasicLinear owns input and weight quantizers. + # Output quantizer is provided by the next op (as its input quantizer). + return [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] + if mode == "backward": + # BasicLinear owns grad_output quantizer. 
+ # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). + return [QuantizerRole(module_type="linear", tensor_type="grad_output", name=name)] + return None + def reset_parameters(self) -> None: """Initialize parameter buffers and values""" diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 54b3f00117..3a59e3b229 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from ..quantization import ( FP8GlobalStateManager, + QuantizerRole, RecipeState, autocast, ) @@ -209,6 +210,15 @@ def num_quantizers( """ return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. + + The returned list must be aligned with the internal quantizer ordering and + must have length ``num_quantizers(mode)`` for supported modes. + Returning ``None`` means "no explicit roles". 
+ """ + return None + def get_input_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("forward") > 0: return self.get_quantizer("forward", 0) @@ -268,10 +278,17 @@ def reset_recipe_state( ) # Construct quantization recipe state + roles = self.get_quantizer_roles(mode) + if roles is not None: + assert len(roles) == num_quantizers, ( + "Recipe roles must match number of quantizers " + f"({len(roles)=} vs {num_quantizers=})" + ) recipe_state = RecipeState.create( recipe, mode=mode, num_quantizers=num_quantizers, + roles=roles, ) fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=(mode == "forward"), diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..c7157dc54d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import dataclasses import itertools import functools import warnings @@ -41,9 +42,105 @@ "is_nvfp4_available", "get_default_recipe", "get_align_size_for_quantization", + "QuantizerRole", + "QuantizerRequest", + "DelayedScalingRequest", ] +@dataclasses.dataclass(frozen=True) +class QuantizerRole: + """Identity of a tensor slot requesting a quantizer. + + TE modules populate all fields they know about. + User factories inspect only the fields they care about. + + .. warning:: + **EXPERIMENTAL**: QuantizerRole is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Fields + ------ + module_type : str + Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. + Empty string when not provided. + tensor_type : str + What tensor is being quantized, in the module's own vocabulary. + Linear modules: `"input"`, `"weight"`, `"grad_output"`, etc. + DPA: `"qkv"`, `"s"`, etc. + Empty string when not provided. + name : str + Caller-provided module instance name (e.g. 
set by the training + framework), e.g. + `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. + Empty string when not provided. + """ + + module_type: str = "" + tensor_type: str = "" + name: str = "" + + def __str__(self) -> str: + parts = [] + if self.module_type: + parts.append(f"module_type={self.module_type}") + if self.tensor_type: + parts.append(f"tensor_type={self.tensor_type}") + if self.name: + parts.append(f"name={self.name}") + return "|".join(parts) if parts else "QuantizerRole()" + + +@dataclasses.dataclass(frozen=True) +class QuantizerRequest: + """Base class for stateful quantizer requests. + + Custom recipe factories return ``QuantizerRequest`` subclasses (instead of + quantizer instances) when the quantizer requires TE-managed shared state. + TE detects these requests, allocates the required state, and replaces them + with real quantizer instances. + + .. warning:: + **EXPERIMENTAL**: QuantizerRequest is experimental, still under active + development, and the API is subject to change without notice. + """ + + +@dataclasses.dataclass(frozen=True) +class DelayedScalingRequest(QuantizerRequest): + """Request a Float8Quantizer with TE-managed delayed scaling state. + + .. warning:: + **EXPERIMENTAL**: DelayedScalingRequest is experimental, still under active + development, and the API is subject to change without notice. + + All ``DelayedScalingRequest`` instances within the same ``CustomRecipeState`` + must share identical parameter values. + + Parameters + ---------- + fp8_format : Format, default = Format.HYBRID + Controls fwd/bwd dtype (HYBRID = E4M3 fwd, E5M2 bwd). + margin : int, default = 0 + Margin for scaling factor computation. + amax_history_len : int, default = 1024 + Length of the amax history window. + amax_compute_algo : str or Callable, default = "max" + Algorithm for choosing amax from history. + scaling_factor_compute_algo : Callable or None, default = None + Custom scaling factor computation. 
+ reduce_amax : bool, default = True + Whether to all-reduce amax across the distributed group. + """ + + fp8_format: Format = Format.HYBRID + margin: int = 0 + amax_history_len: int = 1024 + amax_compute_algo: Union[str, Callable] = "max" + scaling_factor_compute_algo: Optional[Callable] = None + reduce_amax: bool = True + + @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" @@ -362,7 +459,7 @@ def add_fp8_tensors_to_global_buffer( fp8_meta: Dict[str, Any], ) -> None: """ - Delayed scaling only. + Delayed scaling only (built-in or custom recipe with DS requests). The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is @@ -377,8 +474,8 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + # noop unless delayed scaling state is present + if not _has_delayed_scaling_state(fp8_meta): return # Every module must call this function exactly once since @@ -395,18 +492,30 @@ def add_fp8_tensors_to_global_buffer( # Handles non-parameter FP8 modules, e.g. DPA. 
continue - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) + state = fp8_meta[fp8_meta_tensor_key] + + # Determine recipe + buffers: built-in DS or custom with DS requests + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + inner_recipe = state._inner_delayed_scaling_recipe + amax_hist = state.amax_history + scale_tensor = state.scale + key = cls.get_key_in_buffer(forward, inner_recipe, fp8_meta["fp8_group"]) + # Register inner recipe in autocast_arguments for reduction + autocast_key = cls.get_unique_autocast_key(inner_recipe, fp8_meta["fp8_group"]) + cls.autocast_arguments[autocast_key] = (inner_recipe, fp8_meta["fp8_group"]) + else: + amax_hist = state.amax_history + scale_tensor = state.scale + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_amax_buffer[key] = [amax_hist[0]] + cls.global_amax_history_buffer[key] = [amax_hist] + cls.global_scale_buffer[key] = [scale_tensor] else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_amax_buffer[key].append(amax_hist[0]) + cls.global_amax_history_buffer[key].append(amax_hist) + cls.global_scale_buffer[key].append(scale_tensor) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @@ -616,7 +725,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + 
if not _has_delayed_scaling_state(fp8_meta): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -642,7 +751,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return # Store updated amaxes and scales from phase 1 post forward. @@ -661,7 +770,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -992,6 +1101,7 @@ def create( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[list[QuantizerRole]] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe @@ -1005,6 +1115,8 @@ def create( Number of quantizers to create state for. device: torch.device, default = default CUDA device Device for quantized tensors. + roles: list of QuantizerRole, optional + Semantic roles for each quantizer slot. 
Returns ------- @@ -1028,12 +1140,15 @@ def create( cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( + state = cls( recipe, mode=mode, num_quantizers=num_quantizers, device=device, ) + # Optional QuantizerRole objects + state.roles = roles + return state @abc.abstractmethod def make_quantizers(self) -> list: @@ -1353,13 +1468,95 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: raise RuntimeError(f"Unexpected recipe mode ({self.mode})") +def _handle_delayed_scaling_requests( + raw: list, + device: torch.device, + mode: str, +) -> Optional["DelayedScalingRecipeState"]: + """Detect DelayedScalingRequest items, allocate shared state, replace with real quantizers. + + All DS requests in the same RecipeState must share identical parameters. + + Returns a ``DelayedScalingRecipeState`` owning the shared buffers, or + ``None`` when no DS requests are present. + """ + ds_items = [(i, r) for i, r in enumerate(raw) if isinstance(r, DelayedScalingRequest)] + if not ds_items: + return None + + r0 = ds_items[0][1] + + # Validate all DS requests share same params + for idx, req in ds_items[1:]: + for field_name in ( + "fp8_format", + "margin", + "amax_history_len", + "amax_compute_algo", + "scaling_factor_compute_algo", + "reduce_amax", + ): + v0 = getattr(r0, field_name) + vi = getattr(req, field_name) + if v0 != vi: + raise ValueError( + "All DelayedScalingRequests in one CustomRecipeState must match. " + f"Slot 0 has {field_name}={v0!r}, slot {idx} has {vi!r}." + ) + + # Build a real DelayedScalingRecipeState to own the shared buffers. 
+ inner_recipe = DelayedScaling( + fp8_format=r0.fp8_format, + margin=r0.margin, + amax_history_len=r0.amax_history_len, + amax_compute_algo=r0.amax_compute_algo, + scaling_factor_compute_algo=r0.scaling_factor_compute_algo, + reduce_amax=r0.reduce_amax, + ) + n = len(ds_items) + dsrs = DelayedScalingRecipeState( + inner_recipe, + mode=mode, + num_quantizers=n, + device=device, + ) + + # Splice Float8Quantizer instances (backed by dsrs buffers) into raw list. + quantizers = dsrs.make_quantizers() + for j, (idx, _req) in enumerate(ds_items): + raw[idx] = quantizers[j] + + return dsrs + + +def _has_delayed_scaling_state(fp8_meta: Dict[str, Any]) -> bool: + """Check if fp8_meta has delayed scaling state (built-in or custom).""" + if fp8_meta["recipe"].delayed(): + return True + if fp8_meta["recipe"].custom(): + for key in ("scaling_fwd", "scaling_bwd"): + state = fp8_meta.get(key) + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + return True + return False + + class CustomRecipeState(RecipeState): - """State for CustomRecipe: produce quantizers per tensor.""" + """State for CustomRecipe: produce quantizers per tensor. + + Stateful quantizer support: + - Supports stateful quantizers (e.g. delayed scaling) via ``DelayedScalingRequest``. + - The factory returns request dataclasses for stateful quantizers; TE detects them, + allocates shared buffers, and replaces with real quantizer instances. + - Stateful recipe state is composed via real TE recipe state objects (e.g. + ``DelayedScalingRecipeState``), not reimplemented. 
+ """ recipe: CustomRecipe mode: str num_quantizers: int device: Optional[torch.device] + _ds_state: Optional[DelayedScalingRecipeState] def __init__( self, @@ -1375,32 +1572,47 @@ def __init__( if device is None: device = torch.device("cuda") self.device = device + self._ds_state = None if getattr(recipe, "qfactory", None) is None: raise ValueError("CustomRecipe requires `qfactory`.") def make_quantizers(self) -> list: qfactory = self.recipe.qfactory - out = [] - # TODO(negvet): make_quantizers() should take roles from the operation - # Hardcode linear-specific roles for now - roles: List[str] - if self.mode == "forward": - roles = [ - ("linear_input", "linear_weight", "linear_output")[i % 3] - for i in range(self.num_quantizers) - ] - elif self.mode == "backward": - roles = [ - ("linear_grad_output", "linear_grad_input")[i % 2] - for i in range(self.num_quantizers) - ] - else: - roles = ["unknown"] * self.num_quantizers + roles: List[QuantizerRole] = getattr(self, "roles", None) + if roles is None: + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. 
" + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) + roles = [QuantizerRole() for _ in range(self.num_quantizers)] + if len(roles) != self.num_quantizers: + raise ValueError( + "CustomRecipeState requires roles to match num_quantizers " + f"({len(roles)=} vs {self.num_quantizers=})" + ) + + raw = [qfactory(roles[i]) for i in range(self.num_quantizers)] + self._ds_state = _handle_delayed_scaling_requests(raw, self.device, self.mode) + return raw + + # -- Delegation to composed DelayedScalingRecipeState -- + + @property + def _has_delayed_scaling(self) -> bool: + return self._ds_state is not None + + @property + def amax_history(self) -> Optional[torch.Tensor]: + return self._ds_state.amax_history if self._ds_state else None + + @property + def scale(self) -> Optional[torch.Tensor]: + return self._ds_state.scale if self._ds_state else None - for i in range(self.num_quantizers): - # Get quantizer from the user defined factory - quantizer = qfactory(roles[i]) - out.append(quantizer) - return out + @property + def _inner_delayed_scaling_recipe(self) -> Optional[DelayedScaling]: + return self._ds_state.recipe if self._ds_state else None