From cd8b8ad598cdf98c30f3ecf2720ccc230e417b2b Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 23 Jan 2026 15:14:21 +0000 Subject: [PATCH 01/36] Enable semantic roles emitted by module/op and comsumed by custom recipe state Signed-off-by: Evgeny --- transformer_engine/common/recipe/__init__.py | 4 +-- transformer_engine/pytorch/module/base.py | 20 +++++++++++++ .../pytorch/module/grouped_linear.py | 16 ++++++++++ .../pytorch/module/layernorm_linear.py | 20 +++++++++++++ .../pytorch/module/layernorm_mlp.py | 13 +++++++++ transformer_engine/pytorch/module/linear.py | 13 +++++++++ .../pytorch/ops/basic/basic_linear.py | 11 +++++++ transformer_engine/pytorch/ops/op.py | 16 ++++++++++ transformer_engine/pytorch/quantization.py | 29 +++++++++---------- 9 files changed, 125 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..ef30a4a7e6 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -503,8 +503,8 @@ class CustomRecipe(Recipe): Where `role` is one of the following strings for e.g. 
te.Linear (stable public contract): - - forward: "linear_input", "linear_weight", "linear_output" - - backward: "linear_grad_output", "linear_grad_input" + - forward: "input:linear", "weight:linear", "output:linear" + - backward: "grad_output:linear", "grad_input:linear" """ qfactory: Callable[..., Any] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..8e5500694d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -732,15 +732,35 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 # Initialize recipe state and quantizers + roles = self.get_quantizer_roles(fwd=fwd, num_quantizers=num_fp8_tensors) + if roles is not None: + assert len(roles) == num_fp8_tensors, ( + "Recipe roles must match number of quantizers " + f"({len(roles)=} vs {num_fp8_tensors=})" + ) recipe_state = RecipeState.create( recipe, mode=("forward" if fwd else "backward"), num_quantizers=num_fp8_tensors, + roles=roles, ) self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[str]]: + """Return an ordered list of role strings for quantizers. + + The returned list must have length `num_quantizers`. + Returning `None` means "no explicit roles". 
+ """ + return None + def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..7569fb21ca 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -724,6 +724,22 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[str]]: + """Role strings for quantizers used by `GroupedLinear`. + + For grouped GEMMs we repeat the same pattern for each GEMM in order. + """ + if fwd: + base = ("input:grouped_linear", "weight:grouped_linear", "output:grouped_linear") + else: + base = ("grad_output:grouped_linear", "grad_input:grouped_linear") + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..ab6a43e472 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1412,6 +1412,26 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[str]]: + """Role strings for quantizers used by `LayerNormLinear`.""" + if fwd: + base = ( + "input:layernorm_linear", + "weight:layernorm_linear", + "output:layernorm_linear", + ) + else: + base = ( + "grad_output:layernorm_linear", + "grad_input:layernorm_linear", + ) + return 
[base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..ac22bc442e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1968,6 +1968,19 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[str]]: + """Role strings for quantizers used by `LayerNormMLP`.""" + if fwd: + base = ("input:layernorm_mlp", "weight:layernorm_mlp", "output:layernorm_mlp") + else: + base = ("grad_output:layernorm_mlp", "grad_input:layernorm_mlp") + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..0ec6ca8cde 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1308,6 +1308,19 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[str]]: + """Role strings for quantizers used by `Linear`.""" + if fwd: + base = ("input:linear", "weight:linear", "output:linear") + else: + base = ("grad_output:linear", "grad_input:linear") + return [base[i % len(base)] for i in range(num_quantizers)] + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py 
b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..d0755982be 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -270,6 +270,17 @@ def num_quantizers(self, mode: str) -> int: return 1 return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: + if mode == "forward": + # BasicLinear owns input and weight quantizers. + # Output quantizer is provided by the next op (as its input quantizer). + return ["input:linear", "weight:linear"] + if mode == "backward": + # BasicLinear owns grad_output quantizer. + # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). + return ["grad_output:linear"] + return None + def reset_parameters(self) -> None: """Initialize parameter buffers and values""" diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47286dfced..5c3fafe6e8 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -209,6 +209,15 @@ def num_quantizers( """ return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: + """Return an ordered list of role strings for quantizers. + + The returned list must be aligned with the internal quantizer ordering and + must have length `num_quantizers(mode)` for supported modes. + Returning `None` means "no explicit roles". 
+ """ + return None + def get_input_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("forward") > 0: return self.get_quantizer("forward", 0) @@ -268,10 +277,17 @@ def reset_recipe_state( ) # Construct quantization recipe state + roles = self.get_quantizer_roles(mode) + if roles is not None: + assert len(roles) == num_quantizers, ( + "Recipe roles must match number of quantizers " + f"({len(roles)=} vs {num_quantizers=})" + ) recipe_state = RecipeState.create( recipe, mode=mode, num_quantizers=num_quantizers, + roles=roles, ) fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=(mode == "forward"), diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..06f8cf1321 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -992,6 +992,7 @@ def create( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[list[str]] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe @@ -1028,12 +1029,16 @@ def create( cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( + state = cls( recipe, mode=mode, num_quantizers=num_quantizers, device=device, ) + # Optional role strings for quantizers, now only used only by CustomRecipe. + # TODO(negvet): Make all recipe states take roles in their constructors. 
+ state.roles = roles + return state @abc.abstractmethod def make_quantizers(self) -> list: @@ -1383,21 +1388,15 @@ def make_quantizers(self) -> list: qfactory = self.recipe.qfactory out = [] - # TODO(negvet): make_quantizers() should take roles from the operation - # Hardcode linear-specific roles for now roles: List[str] - if self.mode == "forward": - roles = [ - ("linear_input", "linear_weight", "linear_output")[i % 3] - for i in range(self.num_quantizers) - ] - elif self.mode == "backward": - roles = [ - ("linear_grad_output", "linear_grad_input")[i % 2] - for i in range(self.num_quantizers) - ] - else: - roles = ["unknown"] * self.num_quantizers + if getattr(self, "roles", None) is None: + raise ValueError("CustomRecipeState requires roles to be set.") + roles = self.roles + if len(roles) != self.num_quantizers: + raise ValueError( + "CustomRecipeState requires roles to match num_quantizers " + f"({len(roles)=} vs {self.num_quantizers=})" + ) for i in range(self.num_quantizers): # Get quantizer from the user defined factory From fddeba48d735bc70b75ed6a110a3c1be65f14f71 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 23 Jan 2026 15:15:33 +0000 Subject: [PATCH 02/36] Update quantization factories Signed-off-by: Evgeny --- .../custom_recipes/quantization_current_scaling.py | 8 ++++++-- .../pytorch/custom_recipes/quantization_nvfp4.py | 10 +++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 5bdc537e4b..f8b7ccce5b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -23,9 +23,13 @@ def current_scaling_ref_quantizer_factory(role): with autocast(recipe=custom_recipe): output = model(input) """ - if role in ("linear_input", "linear_weight"): + if ":" not in role: + raise 
ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + + if bucket in ("input", "weight"): dtype = torch.float8_e4m3fn - elif role in ("linear_output", "linear_grad_output"): + elif bucket in ("output", "grad_output"): dtype = torch.float8_e5m2 else: return None diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b94..b180603966 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -23,21 +23,25 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): with autocast(fp8_recipe=custom_recipe): output = model(input) """ - if role == "linear_input": + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + + if bucket == "input": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, ) - if role == "linear_weight": + if bucket == "weight": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if role == "linear_grad_output": + if bucket == "grad_output": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), From 82b84ffa6340bbdbf295f152f518828d37af1e46 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 23 Jan 2026 15:18:59 +0000 Subject: [PATCH 03/36] Fix tests Signed-off-by: Evgeny --- .../pytorch/distributed/run_numerics_exact.py | 13 +++-- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 13 +++-- tests/pytorch/test_custom_recipe.py | 55 ++++++++++++------- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 0f3d2cbbf0..eb413affe3 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ 
b/tests/pytorch/distributed/run_numerics_exact.py @@ -60,31 +60,34 @@ def get_nvfp4_quantizer_factory(): """ def factory(role): - if role == "linear_input": + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, # RHT enabled for input ) - elif role == "linear_weight": + elif bucket == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), # 2D quantization for weight pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": + elif bucket == "output": # Output quantization not used return None - elif role == "linear_grad_output": + elif bucket == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, # RHT enabled for grad_output ) - elif role == "linear_grad_input": + elif bucket == "grad_input": # Grad input quantization not used return None else: diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a96fea3af0..a4c776d616 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -80,31 +80,34 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo """ def factory(role): - if role == "linear_input": + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_weight": + elif bucket == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if 
with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": + elif bucket == "output": # Output quantization not used return None - elif role == "linear_grad_output": + elif bucket == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif role == "linear_grad_input": + elif bucket == "grad_input": # Grad input quantization not used return None else: diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 4de49115b3..f1d259d0b9 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -90,9 +90,12 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if bucket in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -127,9 +130,12 @@ def test_custom_recipe_grouped_linear_sanity(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if bucket in ("grad_output", 
"grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -189,9 +195,12 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if bucket in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -246,9 +255,12 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if bucket in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -278,19 +290,22 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): # Counters per role counts = { - "linear_input": 0, - "linear_weight": 0, - "linear_output": 0, - "linear_grad_output": 0, - "linear_grad_input": 0, + "input:linear": 0, + "weight:linear": 0, + "output:linear": 0, + "grad_output:linear": 0, + 
"grad_input:linear": 0, } def quantizer_factory(role): if role in counts: counts[role] += 1 - if role in ("linear_input", "linear_weight", "linear_output"): + if ":" not in role: + raise ValueError(f"Invalid role: {role}, expected format: ':'") + bucket, _ = role.split(":", 1) + if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if role in ("linear_grad_output", "linear_grad_input"): + if bucket in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) @@ -304,11 +319,11 @@ def quantizer_factory(role): loss.backward() # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["linear_input"] == 1 - assert counts["linear_weight"] == 1 - assert counts["linear_output"] == 1 - assert counts["linear_grad_output"] == 1 - assert counts["linear_grad_input"] == 1 + assert counts["input:linear"] == 1 + assert counts["weight:linear"] == 1 + assert counts["output:linear"] == 1 + assert counts["grad_output:linear"] == 1 + assert counts["grad_input:linear"] == 1 def test_factories_return_distinct_instances_and_buffers(): From 434623149ec491e0fa5eff29a25d90b446061f3f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jan 2026 15:32:07 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 9 ++++----- transformer_engine/pytorch/ops/op.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8e5500694d..2961916266 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -734,10 +734,9 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: # Initialize recipe state and quantizers roles = self.get_quantizer_roles(fwd=fwd, num_quantizers=num_fp8_tensors) if roles is not None: - assert len(roles) == num_fp8_tensors, ( - "Recipe roles must match number of quantizers " - f"({len(roles)=} vs {num_fp8_tensors=})" - ) + assert ( + len(roles) == num_fp8_tensors + ), f"Recipe roles must match number of quantizers ({len(roles)=} vs {num_fp8_tensors=})" recipe_state = RecipeState.create( recipe, mode=("forward" if fwd else "backward"), @@ -756,7 +755,7 @@ def get_quantizer_roles( ) -> Optional[List[str]]: """Return an ordered list of role strings for quantizers. - The returned list must have length `num_quantizers`. + The returned list must have length `num_quantizers`. Returning `None` means "no explicit roles". """ return None diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 5c3fafe6e8..1bd0c92b79 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -213,7 +213,7 @@ def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: """Return an ordered list of role strings for quantizers. The returned list must be aligned with the internal quantizer ordering and - must have length `num_quantizers(mode)` for supported modes. + must have length `num_quantizers(mode)` for supported modes. Returning `None` means "no explicit roles". 
""" return None From a81f54a645a3e60a852bf1bceb707b36735d114d Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 27 Jan 2026 10:57:10 -0800 Subject: [PATCH 05/36] Swap tensor:module Signed-off-by: Evgeny --- .../pytorch/distributed/run_numerics_exact.py | 4 +- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 4 +- tests/pytorch/test_custom_recipe.py | 40 +++++++++---------- transformer_engine/common/recipe/__init__.py | 4 +- .../quantization_current_scaling.py | 4 +- .../custom_recipes/quantization_nvfp4.py | 4 +- .../pytorch/module/grouped_linear.py | 4 +- .../pytorch/module/layernorm_linear.py | 10 ++--- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 4 +- .../pytorch/ops/basic/basic_linear.py | 4 +- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index eb413affe3..6e1e0c8deb 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -61,8 +61,8 @@ def get_nvfp4_quantizer_factory(): def factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a4c776d616..6ee1a969dd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -81,8 +81,8 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo def factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, 
bucket = role.split(":", 1) if bucket == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index f1d259d0b9..9462050ec5 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -91,8 +91,8 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") if bucket in ("grad_output", "grad_input"): @@ -131,8 +131,8 @@ def test_custom_recipe_grouped_linear_sanity(): def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") if bucket in ("grad_output", "grad_input"): @@ -196,8 +196,8 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") if bucket in ("grad_output", "grad_input"): @@ -256,8 +256,8 @@ def test_custom_recipe_ops_linear_2_1_layout(): def quantizer_factory(role): if ":" not in 
role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") if bucket in ("grad_output", "grad_input"): @@ -290,19 +290,19 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): # Counters per role counts = { - "input:linear": 0, - "weight:linear": 0, - "output:linear": 0, - "grad_output:linear": 0, - "grad_input:linear": 0, + "linear:input": 0, + "linear:weight": 0, + "linear:output": 0, + "linear:grad_output": 0, + "linear:grad_input": 0, } def quantizer_factory(role): if role in counts: counts[role] += 1 if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) if bucket in ("grad_output", "grad_input"): @@ -319,11 +319,11 @@ def quantizer_factory(role): loss.backward() # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["input:linear"] == 1 - assert counts["weight:linear"] == 1 - assert counts["output:linear"] == 1 - assert counts["grad_output:linear"] == 1 - assert counts["grad_input:linear"] == 1 + assert counts["linear:input"] == 1 + assert counts["linear:weight"] == 1 + assert counts["linear:output"] == 1 + assert counts["linear:grad_output"] == 1 + assert counts["linear:grad_input"] == 1 def test_factories_return_distinct_instances_and_buffers(): diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ef30a4a7e6..2c6b171e9f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ 
b/transformer_engine/common/recipe/__init__.py @@ -503,8 +503,8 @@ class CustomRecipe(Recipe): Where `role` is one of the following strings for e.g. te.Linear (stable public contract): - - forward: "input:linear", "weight:linear", "output:linear" - - backward: "grad_output:linear", "grad_input:linear" + - forward: "linear:input", "linear:weight", "linear:output" + - backward: "linear:grad_output", "linear:grad_input" """ qfactory: Callable[..., Any] diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index f8b7ccce5b..ebda402ec8 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -24,8 +24,8 @@ def current_scaling_ref_quantizer_factory(role): output = model(input) """ if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket in ("input", "weight"): dtype = torch.float8_e4m3fn diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index b180603966..23d6d6e589 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -24,8 +24,8 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): output = model(input) """ if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - bucket, _ = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, bucket = role.split(":", 1) if bucket == "input": return NVFP4QuantizerRef( diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 
7569fb21ca..bfaffdf2ed 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -735,9 +735,9 @@ def get_quantizer_roles( For grouped GEMMs we repeat the same pattern for each GEMM in order. """ if fwd: - base = ("input:grouped_linear", "weight:grouped_linear", "output:grouped_linear") + base = ("grouped_linear:input", "grouped_linear:weight", "grouped_linear:output") else: - base = ("grad_output:grouped_linear", "grad_input:grouped_linear") + base = ("grouped_linear:grad_output", "grouped_linear:grad_input") return [base[i % len(base)] for i in range(num_quantizers)] def reset_parameters(self, defer_init=False): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ab6a43e472..98db95011e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1421,14 +1421,14 @@ def get_quantizer_roles( """Role strings for quantizers used by `LayerNormLinear`.""" if fwd: base = ( - "input:layernorm_linear", - "weight:layernorm_linear", - "output:layernorm_linear", + "layernorm_linear:input", + "layernorm_linear:weight", + "layernorm_linear:output", ) else: base = ( - "grad_output:layernorm_linear", - "grad_input:layernorm_linear", + "layernorm_linear:grad_output", + "layernorm_linear:grad_input", ) return [base[i % len(base)] for i in range(num_quantizers)] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ac22bc442e..33f4849fa2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1976,9 +1976,9 @@ def get_quantizer_roles( ) -> Optional[List[str]]: """Role strings for quantizers used by `LayerNormMLP`.""" if fwd: - base = ("input:layernorm_mlp", "weight:layernorm_mlp", "output:layernorm_mlp") + base = 
("layernorm_mlp:input", "layernorm_mlp:weight", "layernorm_mlp:output") else: - base = ("grad_output:layernorm_mlp", "grad_input:layernorm_mlp") + base = ("layernorm_mlp:grad_output", "layernorm_mlp:grad_input") return [base[i % len(base)] for i in range(num_quantizers)] def reset_layer_norm_parameters(self) -> None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0ec6ca8cde..bdbf65722e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1316,9 +1316,9 @@ def get_quantizer_roles( ) -> Optional[List[str]]: """Role strings for quantizers used by `Linear`.""" if fwd: - base = ("input:linear", "weight:linear", "output:linear") + base = ("linear:input", "linear:weight", "linear:output") else: - base = ("grad_output:linear", "grad_input:linear") + base = ("linear:grad_output", "linear:grad_input") return [base[i % len(base)] for i in range(num_quantizers)] def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index d0755982be..530bd0f9c1 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -274,11 +274,11 @@ def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: if mode == "forward": # BasicLinear owns input and weight quantizers. # Output quantizer is provided by the next op (as its input quantizer). - return ["input:linear", "weight:linear"] + return ["linear:input", "linear:weight"] if mode == "backward": # BasicLinear owns grad_output quantizer. # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). 
- return ["grad_output:linear"] + return ["linear:grad_output"] return None def reset_parameters(self) -> None: From 700ea04396f6fe2534f6c4c1f8802f79b16d2f9b Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 27 Jan 2026 14:02:14 -0800 Subject: [PATCH 06/36] Better naming Signed-off-by: Evgeny --- .../pytorch/distributed/run_numerics_exact.py | 14 +++---- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 14 +++---- tests/pytorch/test_custom_recipe.py | 40 +++++++++---------- .../quantization_current_scaling.py | 8 ++-- .../custom_recipes/quantization_nvfp4.py | 10 ++--- 5 files changed, 43 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 6e1e0c8deb..63fa96407e 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -61,33 +61,33 @@ def get_nvfp4_quantizer_factory(): def factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket == "input": + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, # RHT enabled for input ) - elif bucket == "weight": + elif tensor_type == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), # 2D quantization for weight pow_2_scales=False, with_rht=False, ) - elif bucket == "output": + elif tensor_type == "output": # Output quantization not used return None - elif bucket == "grad_output": + elif tensor_type == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, # RHT enabled for grad_output ) - elif bucket == "grad_input": + elif 
tensor_type == "grad_input": # Grad input quantization not used return None else: diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index 6ee1a969dd..4ca4d89a1b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -81,33 +81,33 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo def factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket == "input": + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif bucket == "weight": + elif tensor_type == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif bucket == "output": + elif tensor_type == "output": # Output quantization not used return None - elif bucket == "grad_output": + elif tensor_type == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif bucket == "grad_input": + elif tensor_type == "grad_input": # Grad input quantization not used return None else: diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 9462050ec5..ec0eaf716e 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -91,11 +91,11 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - 
if bucket in ("input", "weight", "output"): + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if bucket in ("grad_output", "grad_input"): + if tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -131,11 +131,11 @@ def test_custom_recipe_grouped_linear_sanity(): def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket in ("input", "weight", "output"): + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if bucket in ("grad_output", "grad_input"): + if tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -196,11 +196,11 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket in ("input", "weight", "output"): + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if bucket in ("grad_output", "grad_input"): + if tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, 
device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -256,11 +256,11 @@ def test_custom_recipe_ops_linear_2_1_layout(): def quantizer_factory(role): if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket in ("input", "weight", "output"): + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if bucket in ("grad_output", "grad_input"): + if tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -301,11 +301,11 @@ def quantizer_factory(role): if role in counts: counts[role] += 1 if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) - if bucket in ("input", "weight", "output"): + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) + if tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if bucket in ("grad_output", "grad_input"): + if tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index ebda402ec8..e7ef5e7d00 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -24,12 +24,12 @@ def 
current_scaling_ref_quantizer_factory(role): output = model(input) """ if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) - if bucket in ("input", "weight"): + if tensor_type in ("input", "weight"): dtype = torch.float8_e4m3fn - elif bucket in ("output", "grad_output"): + elif tensor_type in ("output", "grad_output"): dtype = torch.float8_e5m2 else: return None diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 23d6d6e589..535a667412 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -24,24 +24,24 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): output = model(input) """ if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, bucket = role.split(":", 1) + raise ValueError(f"Invalid role: {role}, expected format: ':'") + _, tensor_type = role.split(":", 1) - if bucket == "input": + if tensor_type == "input": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, ) - if bucket == "weight": + if tensor_type == "weight": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if bucket == "grad_output": + if tensor_type == "grad_output": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), From d7ca20bfea0e08adba52a452b2d99325f77948f3 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 17 Feb 2026 15:37:53 +0100 Subject: [PATCH 07/36] Introduce QuantizerRole frozen data class instead of a string Signed-off-by: Evgeny --- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 18 ++---- tests/pytorch/test_custom_recipe.py | 64 
++++++++----------- transformer_engine/common/recipe/__init__.py | 32 ++++++---- transformer_engine/pytorch/__init__.py | 1 + .../quantization_current_scaling.py | 13 ++-- .../custom_recipes/quantization_nvfp4.py | 17 +++-- transformer_engine/pytorch/module/base.py | 5 +- .../pytorch/module/grouped_linear.py | 18 ++++-- .../pytorch/module/layernorm_linear.py | 25 ++++---- .../pytorch/module/layernorm_mlp.py | 28 ++++++-- transformer_engine/pytorch/module/linear.py | 18 ++++-- .../pytorch/ops/basic/basic_linear.py | 12 ++-- transformer_engine/pytorch/ops/op.py | 9 +-- transformer_engine/pytorch/quantization.py | 55 ++++++++++++++-- 14 files changed, 195 insertions(+), 120 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index 4ca4d89a1b..0977d2a9d9 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -76,42 +76,36 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type == "input": + if role.tensor_type == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif tensor_type == "weight": + elif role.tensor_type == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif tensor_type == "output": - # Output quantization not used + elif role.tensor_type == "output": return 
None - elif tensor_type == "grad_output": + elif role.tensor_type == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) - elif tensor_type == "grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index ec0eaf716e..36e1cc2744 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -16,6 +16,7 @@ GroupedLinear, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.quantization import QuantizerRole import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( nvfp4_ref_rht_2d_quantizer_factory, @@ -90,12 +91,9 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type in ("input", "weight", "output"): + if role.tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -130,12 +128,9 @@ def test_custom_recipe_grouped_linear_sanity(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type in ("input", "weight", "output"): + if 
role.tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -195,12 +190,9 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type in ("input", "weight", "output"): + if role.tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -255,12 +247,9 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type in ("input", "weight", "output"): + if role.tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -288,24 +277,23 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): op = 
Linear(in_features, out_features, params_dtype=torch.bfloat16) inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) - # Counters per role + # Counters per tensor_type counts = { - "linear:input": 0, - "linear:weight": 0, - "linear:output": 0, - "linear:grad_output": 0, - "linear:grad_input": 0, + "input": 0, + "weight": 0, + "output": 0, + "grad_output": 0, + "grad_input": 0, } def quantizer_factory(role): - if role in counts: - counts[role] += 1 - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type in ("input", "weight", "output"): + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + assert role.module_type == "linear" + if role.tensor_type in counts: + counts[role.tensor_type] += 1 + if role.tensor_type in ("input", "weight", "output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output", "grad_input"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) @@ -319,11 +307,11 @@ def quantizer_factory(role): loss.backward() # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["linear:input"] == 1 - assert counts["linear:weight"] == 1 - assert counts["linear:output"] == 1 - assert counts["linear:grad_output"] == 1 - assert counts["linear:grad_input"] == 1 + assert counts["input"] == 1 + assert counts["weight"] == 1 + assert counts["output"] == 1 + assert counts["grad_output"] == 1 + assert counts["grad_input"] == 1 def test_factories_return_distinct_instances_and_buffers(): diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 
2c6b171e9f..e1d4808527 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -492,19 +492,29 @@ class CustomRecipe(Recipe): Parameters ---------- qfactory : Callable - Factory callable that returns a quantizer instance for a - given semantic tensor role. - The callable is typically invoked as:: + Factory callable that returns a quantizer instance for a given `QuantizerRole`. + The callable is invoked as:: qfactory( - role: str, - ) - - Where `role` is one of the following strings for e.g. te.Linear - (stable public contract): - - - forward: "linear:input", "linear:weight", "linear:output" - - backward: "linear:grad_output", "linear:grad_input" + role: QuantizerRole, + ) -> Optional[Quantizer] + + `QuantizerRole` is a frozen dataclass with the following fields: + + - `module_type` (str): TE module class, e.g. `"linear"`, + `"layernorm_linear"`, `"layernorm_mlp"`, `"grouped_linear"`, + `"dpa"`. + - `tensor_type` (str): what tensor is being quantized, e.g. + `"input"`, `"weight"`, `"output"`, `"grad_output"`, + `"grad_input"`. + - `name` (str): caller-provided module instance name (empty + string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. + - `position` (str): module-internal sub-slot within compound + modules, e.g. `"fc1"` / `"fc2"` inside `LayerNormMLP` + (empty string for simple modules). + + See `transformer_engine.pytorch.quantization.QuantizerRole` + for full documentation. 
""" qfactory: Callable[..., Any] diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..4880959546 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -48,6 +48,7 @@ from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import is_bf16_available diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index e7ef5e7d00..278505ddc8 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -18,18 +18,17 @@ def current_scaling_ref_quantizer_factory(role): """Factory function for current scaling reference quantizer. - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. 
+ + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory) with autocast(recipe=custom_recipe): output = model(input) """ - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - - if tensor_type in ("input", "weight"): + if role.tensor_type in ("input", "weight"): dtype = torch.float8_e4m3fn - elif tensor_type in ("output", "grad_output"): + elif role.tensor_type in ("output", "grad_output"): dtype = torch.float8_e5m2 else: return None diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 535a667412..1c2db63859 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -18,30 +18,29 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): """ Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. 
+ + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) - with autocast(fp8_recipe=custom_recipe): + with autocast(recipe=custom_recipe): output = model(input) """ - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - - if tensor_type == "input": + if role.tensor_type == "input": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, ) - if tensor_type == "weight": + if role.tensor_type == "weight": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if tensor_type == "grad_output": + if role.tensor_type == "grad_output": return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2961916266..4db37f12f6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -28,6 +28,7 @@ Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState, FP8GlobalStateManager, + QuantizerRole, RecipeState, ) from ..distributed import ( @@ -752,8 +753,8 @@ def get_quantizer_roles( *, fwd: bool, num_quantizers: int, - ) -> Optional[List[str]]: - """Return an ordered list of role strings for quantizers. + ) -> Optional[List[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. The returned list must have length `num_quantizers`. Returning `None` means "no explicit roles". 
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index bfaffdf2ed..781e7d1571 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -21,7 +21,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, cast_if_needed, @@ -729,15 +729,23 @@ def get_quantizer_roles( *, fwd: bool, num_quantizers: int, - ) -> Optional[List[str]]: - """Role strings for quantizers used by `GroupedLinear`. + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``GroupedLinear``. For grouped GEMMs we repeat the same pattern for each GEMM in order. """ + name = self.name or "" if fwd: - base = ("grouped_linear:input", "grouped_linear:weight", "grouped_linear:output") + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="input", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="weight", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="output", name=name), + ] else: - base = ("grouped_linear:grad_output", "grouped_linear:grad_input") + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="grad_output", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="grad_input", name=name), + ] return [base[i % len(base)] for i in range(num_quantizers)] def reset_parameters(self, defer_init=False): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 98db95011e..fe38150d12 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -26,7 +26,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import 
FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, assert_dim_for_all_gather, @@ -1417,19 +1417,20 @@ def get_quantizer_roles( *, fwd: bool, num_quantizers: int, - ) -> Optional[List[str]]: - """Role strings for quantizers used by `LayerNormLinear`.""" + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormLinear``.""" + name = self.name or "" if fwd: - base = ( - "layernorm_linear:input", - "layernorm_linear:weight", - "layernorm_linear:output", - ) + base = [ + QuantizerRole(module_type="layernorm_linear", tensor_type="input", name=name), + QuantizerRole(module_type="layernorm_linear", tensor_type="weight", name=name), + QuantizerRole(module_type="layernorm_linear", tensor_type="output", name=name), + ] else: - base = ( - "layernorm_linear:grad_output", - "layernorm_linear:grad_input", - ) + base = [ + QuantizerRole(module_type="layernorm_linear", tensor_type="grad_output", name=name), + QuantizerRole(module_type="layernorm_linear", tensor_type="grad_input", name=name), + ] return [base[i % len(base)] for i in range(num_quantizers)] def reset_layer_norm_parameters(self) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 33f4849fa2..a02cb94551 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -27,7 +27,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -1973,13 +1973,29 @@ def get_quantizer_roles( *, fwd: bool, num_quantizers: int, - ) -> Optional[List[str]]: - """Role strings for quantizers used by `LayerNormMLP`.""" + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormMLP``. + + Uses ``position`` to distinguish FC1 and FC2 sub-operations. 
+ """ + name = self.name or "" if fwd: - base = ("layernorm_mlp:input", "layernorm_mlp:weight", "layernorm_mlp:output") + roles = [ + QuantizerRole(module_type="layernorm_mlp", tensor_type="input", name=name, position="fc1"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name, position="fc1"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name, position="fc1"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="input", name=name, position="fc2"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name, position="fc2"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name, position="fc2"), + ] else: - base = ("layernorm_mlp:grad_output", "layernorm_mlp:grad_input") - return [base[i % len(base)] for i in range(num_quantizers)] + roles = [ + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name, position="fc1"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name, position="fc1"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name, position="fc2"), + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name, position="fc2"), + ] + return roles[:num_quantizers] def reset_layer_norm_parameters(self) -> None: """Init LN params""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bdbf65722e..b44172eb20 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, clear_tensor_data, @@ -1313,12 +1313,20 @@ def get_quantizer_roles( *, fwd: bool, num_quantizers: int, - ) -> Optional[List[str]]: - """Role strings for quantizers used by 
`Linear`.""" + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``Linear``.""" + name = self.name or "" if fwd: - base = ("linear:input", "linear:weight", "linear:output") + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + QuantizerRole(module_type="linear", tensor_type="output", name=name), + ] else: - base = ("linear:grad_output", "linear:grad_input") + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), + ] return [base[i % len(base)] for i in range(num_quantizers)] def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 530bd0f9c1..18f0ce0a62 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -270,15 +270,19 @@ def num_quantizers(self, mode: str) -> int: return 1 return 0 - def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" if mode == "forward": # BasicLinear owns input and weight quantizers. # Output quantizer is provided by the next op (as its input quantizer). 
- return ["linear:input", "linear:weight"] + return [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] if mode == "backward": # BasicLinear owns grad_output quantizer. # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). - return ["linear:grad_output"] + return [QuantizerRole(module_type="linear", tensor_type="grad_output", name=name)] return None def reset_parameters(self) -> None: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 1bd0c92b79..6cf92b042b 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from ..quantization import ( FP8GlobalStateManager, + QuantizerRole, RecipeState, autocast, ) @@ -209,12 +210,12 @@ def num_quantizers( """ return 0 - def get_quantizer_roles(self, mode: str) -> Optional[list[str]]: - """Return an ordered list of role strings for quantizers. + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. The returned list must be aligned with the internal quantizer ordering and - must have length `num_quantizers(mode)` for supported modes. - Returning `None` means "no explicit roles". + must have length ``num_quantizers(mode)`` for supported modes. + Returning ``None`` means "no explicit roles". 
""" return None diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 06f8cf1321..9e14f9a4dc 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import dataclasses import itertools import functools import warnings @@ -41,9 +42,53 @@ "is_nvfp4_available", "get_default_recipe", "get_align_size_for_quantization", + "QuantizerRole", ] +@dataclasses.dataclass(frozen=True) +class QuantizerRole: + """Identity of a tensor slot requesting a quantizer. + + TE modules populate all fields they know about. + User factories inspect only the fields they care about. + + Fields + ------ + module_type : str + TE module class that emits this role, e.g. + `"linear"`, `"layernorm_linear"`, `"layernorm_mlp"`, + `"grouped_linear"`, `"dpa"`. + tensor_type : str + What tensor is being quantized, in the module's own vocabulary. + GEMM modules: `"input"`, `"weight"`, `"output"`, + `"grad_output"`, `"grad_input"`. + DPA: `"qkv"`, `"o"`, `"s"`, `"dqkv"`, `"do"`, `"dp"`. + name : str + Caller-provided module instance name (e.g. set by the training + framework), e.g. + `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. + Empty string when not provided. + position : str + Module-internal sub-slot. For modules that fuse multiple sequential operations, + e.g. `LayerNormMLP` has `"fc1"` and `"fc2"` sub-slots. + Empty string for simple modules. 
+ """ + + module_type: str + tensor_type: str + name: str = "" + position: str = "" + + def __str__(self) -> str: + parts = [f"{self.module_type}:{self.tensor_type}"] + if self.name: + parts.append(f"name={self.name}") + if self.position: + parts.append(f"position={self.position}") + return "|".join(parts) + + @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" @@ -992,7 +1037,7 @@ def create( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, - roles: Optional[list[str]] = None, + roles: Optional[list[QuantizerRole]] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe @@ -1006,6 +1051,8 @@ def create( Number of quantizers to create state for. device: torch.device, default = default CUDA device Device for quantized tensors. + roles: list of QuantizerRole, optional + Semantic roles for each quantizer slot. Returns ------- @@ -1035,8 +1082,7 @@ def create( num_quantizers=num_quantizers, device=device, ) - # Optional role strings for quantizers, now only used only by CustomRecipe. - # TODO(negvet): Make all recipe states take roles in their constructors. 
+ # Optional QuantizerRole objects state.roles = roles return state @@ -1388,7 +1434,7 @@ def make_quantizers(self) -> list: qfactory = self.recipe.qfactory out = [] - roles: List[str] + roles: List[QuantizerRole] if getattr(self, "roles", None) is None: raise ValueError("CustomRecipeState requires roles to be set.") roles = self.roles @@ -1399,7 +1445,6 @@ def make_quantizers(self) -> list: ) for i in range(self.num_quantizers): - # Get quantizer from the user defined factory quantizer = qfactory(roles[i]) out.append(quantizer) return out From ed595564ad32c86ed55446b415cba7b74b55615a Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 19 Feb 2026 13:58:46 +0100 Subject: [PATCH 08/36] Shrink module_type vocabulary Signed-off-by: Evgeny --- transformer_engine/common/recipe/__init__.py | 8 +++----- .../pytorch/module/layernorm_linear.py | 10 +++++----- .../pytorch/module/layernorm_mlp.py | 20 +++++++++---------- transformer_engine/pytorch/quantization.py | 9 +++------ 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e1d4808527..171679daaa 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -501,12 +501,10 @@ class CustomRecipe(Recipe): `QuantizerRole` is a frozen dataclass with the following fields: - - `module_type` (str): TE module class, e.g. `"linear"`, - `"layernorm_linear"`, `"layernorm_mlp"`, `"grouped_linear"`, - `"dpa"`. + - `module_type` (str): module type, e.g. + `"linear"`, `"grouped_linear"`, `"dpa"`. - `tensor_type` (str): what tensor is being quantized, e.g. - `"input"`, `"weight"`, `"output"`, `"grad_output"`, - `"grad_input"`. + `"input"`, `"weight"`, `"grad_output"`, etc. - `name` (str): caller-provided module instance name (empty string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. 
- `position` (str): module-internal sub-slot within compound diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fe38150d12..22fe52310b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1422,14 +1422,14 @@ def get_quantizer_roles( name = self.name or "" if fwd: base = [ - QuantizerRole(module_type="layernorm_linear", tensor_type="input", name=name), - QuantizerRole(module_type="layernorm_linear", tensor_type="weight", name=name), - QuantizerRole(module_type="layernorm_linear", tensor_type="output", name=name), + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + QuantizerRole(module_type="linear", tensor_type="output", name=name), ] else: base = [ - QuantizerRole(module_type="layernorm_linear", tensor_type="grad_output", name=name), - QuantizerRole(module_type="layernorm_linear", tensor_type="grad_input", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), ] return [base[i % len(base)] for i in range(num_quantizers)] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a02cb94551..7251c20df4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1981,19 +1981,19 @@ def get_quantizer_roles( name = self.name or "" if fwd: roles = [ - QuantizerRole(module_type="layernorm_mlp", tensor_type="input", name=name, position="fc1"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name, position="fc1"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name, position="fc1"), - QuantizerRole(module_type="layernorm_mlp", 
tensor_type="input", name=name, position="fc2"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name, position="fc2"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name, position="fc2"), + QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc1"), + QuantizerRole(module_type="linear", tensor_type="weight", name=name, position="fc1"), + QuantizerRole(module_type="linear", tensor_type="output", name=name, position="fc1"), + QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc2"), + QuantizerRole(module_type="linear", tensor_type="weight", name=name, position="fc2"), + QuantizerRole(module_type="linear", tensor_type="output", name=name, position="fc2"), ] else: roles = [ - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name, position="fc1"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name, position="fc1"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name, position="fc2"), - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name, position="fc2"), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name, position="fc1"), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=name, position="fc1"), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name, position="fc2"), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=name, position="fc2"), ] return roles[:num_quantizers] diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9e14f9a4dc..5fc8b28e24 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -56,14 +56,11 @@ class QuantizerRole: Fields ------ module_type : str - TE module class that emits this role, e.g. 
- `"linear"`, `"layernorm_linear"`, `"layernorm_mlp"`, - `"grouped_linear"`, `"dpa"`. + Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. tensor_type : str What tensor is being quantized, in the module's own vocabulary. - GEMM modules: `"input"`, `"weight"`, `"output"`, - `"grad_output"`, `"grad_input"`. - DPA: `"qkv"`, `"o"`, `"s"`, `"dqkv"`, `"do"`, `"dp"`. + GEMM modules: `"input"`, `"weight"`, `"grad_output"`, etc. + DPA: `"qkv"`, `"s"`, etc. name : str Caller-provided module instance name (e.g. set by the training framework), e.g. From b1a4aedbf659eab28aa5c490c3561ab7f56510f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 13:32:33 +0000 Subject: [PATCH 09/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/module/layernorm_mlp.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4e0d7456f0..1d7c0910e2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1993,18 +1993,34 @@ def get_quantizer_roles( if fwd: roles = [ QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc1"), - QuantizerRole(module_type="linear", tensor_type="weight", name=name, position="fc1"), - QuantizerRole(module_type="linear", tensor_type="output", name=name, position="fc1"), + QuantizerRole( + module_type="linear", tensor_type="weight", name=name, position="fc1" + ), + QuantizerRole( + module_type="linear", tensor_type="output", name=name, position="fc1" + ), QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc2"), - QuantizerRole(module_type="linear", tensor_type="weight", name=name, position="fc2"), - QuantizerRole(module_type="linear", 
tensor_type="output", name=name, position="fc2"), + QuantizerRole( + module_type="linear", tensor_type="weight", name=name, position="fc2" + ), + QuantizerRole( + module_type="linear", tensor_type="output", name=name, position="fc2" + ), ] else: roles = [ - QuantizerRole(module_type="linear", tensor_type="grad_output", name=name, position="fc1"), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=name, position="fc1"), - QuantizerRole(module_type="linear", tensor_type="grad_output", name=name, position="fc2"), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=name, position="fc2"), + QuantizerRole( + module_type="linear", tensor_type="grad_output", name=name, position="fc1" + ), + QuantizerRole( + module_type="linear", tensor_type="grad_input", name=name, position="fc1" + ), + QuantizerRole( + module_type="linear", tensor_type="grad_output", name=name, position="fc2" + ), + QuantizerRole( + module_type="linear", tensor_type="grad_input", name=name, position="fc2" + ), ] return roles[:num_quantizers] From 6e1ee37e880757b8a5c8cfeb12592ab6b02d5103 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 20 Feb 2026 15:05:12 +0100 Subject: [PATCH 10/36] Fix numerics exact test Signed-off-by: Evgeny --- .../pytorch/distributed/run_numerics_exact.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 63fa96407e..65af23c490 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -56,42 +56,36 @@ def get_nvfp4_quantizer_factory(): enabled. 
Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if ":" not in role: - raise ValueError(f"Invalid role: {role}, expected format: ':'") - _, tensor_type = role.split(":", 1) - if tensor_type == "input": + if role.tensor_type == "input": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for input + with_rht=True, ) - elif tensor_type == "weight": + elif role.tensor_type == "weight": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), # 2D quantization for weight + quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - elif tensor_type == "output": - # Output quantization not used + elif role.tensor_type == "output": return None - elif tensor_type == "grad_output": + elif role.tensor_type == "grad_output": return quantization_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, - with_rht=True, # RHT enabled for grad_output + with_rht=True, ) - elif tensor_type == "grad_input": - # Grad input quantization not used + elif role.tensor_type == "grad_input": return None else: - # For any other roles, return None return None return factory From b9753f2c32bd09ecfceeef156ab24cb93a4be474 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 20 Feb 2026 16:10:28 +0100 Subject: [PATCH 11/36] Set defaults, make custom recipe forward compatible Signed-off-by: Evgeny --- transformer_engine/pytorch/quantization.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 5fc8b28e24..debe83af3a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -72,8 +72,8 @@ 
class QuantizerRole: Empty string for simple modules. """ - module_type: str - tensor_type: str + module_type: str = "" + tensor_type: str = "" name: str = "" position: str = "" @@ -1429,18 +1429,23 @@ def __init__( def make_quantizers(self) -> list: qfactory = self.recipe.qfactory - out = [] - roles: List[QuantizerRole] - if getattr(self, "roles", None) is None: - raise ValueError("CustomRecipeState requires roles to be set.") - roles = self.roles + roles: List[QuantizerRole] = getattr(self, "roles", None) + if roles is None: + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. " + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) + roles = [QuantizerRole() for _ in range(self.num_quantizers)] if len(roles) != self.num_quantizers: raise ValueError( "CustomRecipeState requires roles to match num_quantizers " f"({len(roles)=} vs {self.num_quantizers=})" ) + out = [] for i in range(self.num_quantizers): quantizer = qfactory(roles[i]) out.append(quantizer) From ad672474774612bc4b93cb2488b8635816576e56 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 20 Feb 2026 16:32:41 +0100 Subject: [PATCH 12/36] remove position from QuantizerRole Signed-off-by: Evgeny --- transformer_engine/common/recipe/__init__.py | 9 ++-- .../pytorch/module/layernorm_mlp.py | 42 ++++--------------- transformer_engine/pytorch/quantization.py | 17 ++++---- 3 files changed, 20 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e3520013b8..2860fcf7f4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -508,15 +508,12 @@ class CustomRecipe(Recipe): `QuantizerRole` is a frozen dataclass with the following fields: - - `module_type` (str): module type, e.g. + - `module_type` (str): module type (empty string when not set), e.g. 
`"linear"`, `"grouped_linear"`, `"dpa"`. - - `tensor_type` (str): what tensor is being quantized, e.g. - `"input"`, `"weight"`, `"grad_output"`, etc. + - `tensor_type` (str): what tensor is being quantized (empty + string when not set), e.g. `"input"`, `"weight"`, `"grad_output"`. - `name` (str): caller-provided module instance name (empty string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. - - `position` (str): module-internal sub-slot within compound - modules, e.g. `"fc1"` / `"fc2"` inside `LayerNormMLP` - (empty string for simple modules). See `transformer_engine.pytorch.quantization.QuantizerRole` for full documentation. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1d7c0910e2..63adcce6b2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1985,44 +1985,20 @@ def get_quantizer_roles( fwd: bool, num_quantizers: int, ) -> Optional[List[QuantizerRole]]: - """QuantizerRole list for quantizers used by ``LayerNormMLP``. - - Uses ``position`` to distinguish FC1 and FC2 sub-operations. 
- """ + """QuantizerRole list for quantizers used by ``LayerNormMLP``.""" name = self.name or "" if fwd: - roles = [ - QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc1"), - QuantizerRole( - module_type="linear", tensor_type="weight", name=name, position="fc1" - ), - QuantizerRole( - module_type="linear", tensor_type="output", name=name, position="fc1" - ), - QuantizerRole(module_type="linear", tensor_type="input", name=name, position="fc2"), - QuantizerRole( - module_type="linear", tensor_type="weight", name=name, position="fc2" - ), - QuantizerRole( - module_type="linear", tensor_type="output", name=name, position="fc2" - ), + base = [ + QuantizerRole(module_type="layernorm_mlp", tensor_type="input", name=name), + QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name), + QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name), ] else: - roles = [ - QuantizerRole( - module_type="linear", tensor_type="grad_output", name=name, position="fc1" - ), - QuantizerRole( - module_type="linear", tensor_type="grad_input", name=name, position="fc1" - ), - QuantizerRole( - module_type="linear", tensor_type="grad_output", name=name, position="fc2" - ), - QuantizerRole( - module_type="linear", tensor_type="grad_input", name=name, position="fc2" - ), + base = [ + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name), + QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name), ] - return roles[:num_quantizers] + return [base[i % len(base)] for i in range(num_quantizers)] def reset_layer_norm_parameters(self) -> None: """Init LN params""" diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index debe83af3a..806054abba 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -57,33 +57,32 @@ class QuantizerRole: ------ module_type : str Module type that emits this 
role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. + Empty string when not provided. tensor_type : str What tensor is being quantized, in the module's own vocabulary. GEMM modules: `"input"`, `"weight"`, `"grad_output"`, etc. DPA: `"qkv"`, `"s"`, etc. + Empty string when not provided. name : str Caller-provided module instance name (e.g. set by the training framework), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. Empty string when not provided. - position : str - Module-internal sub-slot. For modules that fuse multiple sequential operations, - e.g. `LayerNormMLP` has `"fc1"` and `"fc2"` sub-slots. - Empty string for simple modules. """ module_type: str = "" tensor_type: str = "" name: str = "" - position: str = "" def __str__(self) -> str: - parts = [f"{self.module_type}:{self.tensor_type}"] + parts = [] + if self.module_type: + parts.append(f"module_type={self.module_type}") + if self.tensor_type: + parts.append(f"tensor_type={self.tensor_type}") if self.name: parts.append(f"name={self.name}") - if self.position: - parts.append(f"position={self.position}") - return "|".join(parts) + return "|".join(parts) if parts else "QuantizerRole()" @functools.lru_cache(maxsize=None) From e6be76aaa06efdfa985f4f4d6cab6023ebc5aeba Mon Sep 17 00:00:00 2001 From: Evgeny Date: Fri, 20 Feb 2026 17:26:26 +0100 Subject: [PATCH 13/36] Set good defaults Signed-off-by: Evgeny --- .../quantization_current_scaling.py | 8 +++---- .../custom_recipes/quantization_nvfp4.py | 23 ++++++------------- .../pytorch/module/layernorm_mlp.py | 10 ++++---- transformer_engine/pytorch/quantization.py | 12 ++++++++++ 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 278505ddc8..9febf06b8e 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ 
b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
@@ -20,18 +20,18 @@ def current_scaling_ref_quantizer_factory(role):
 
     Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`.
 
+    Backward tensors use E5M2; everything else uses E4M3.
+
     Usage with CustomRecipe and autocast::
 
         custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
         with autocast(recipe=custom_recipe):
             output = model(input)
     """
-    if role.tensor_type in ("input", "weight"):
-        dtype = torch.float8_e4m3fn
-    elif role.tensor_type in ("output", "grad_output"):
+    if role.tensor_type in ("grad_output", "grad_input"):
         dtype = torch.float8_e5m2
     else:
-        return None
+        dtype = torch.float8_e4m3fn
     return CurrentScalingQuantizerRef(
         dtype=dtype,
         rowwise=True,
diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
index 1c2db63859..25b7159c27 100644
--- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
+++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
@@ -26,28 +26,19 @@ def nvfp4_ref_rht_2d_quantizer_factory(role):
         with autocast(recipe=custom_recipe):
             output = model(input)
     """
-    if role.tensor_type == "input":
-        return NVFP4QuantizerRef(
-            dtype=utils.Fp4Formats.E2M1,
-            quant_tile_shape=(1, 16),
-            pow_2_scales=False,
-            with_rht=True,
-        )
-    if role.tensor_type == "weight":
+    if role.is_gemm() and role.tensor_type == "weight":
         return NVFP4QuantizerRef(
             dtype=utils.Fp4Formats.E2M1,
             quant_tile_shape=(16, 16),
             pow_2_scales=False,
             with_rht=False,
         )
-    if role.tensor_type == "grad_output":
-        return NVFP4QuantizerRef(
-            dtype=utils.Fp4Formats.E2M1,
-            quant_tile_shape=(1, 16),
-            pow_2_scales=False,
-            with_rht=True,
-        )
-    return None
+    return NVFP4QuantizerRef(
+        dtype=utils.Fp4Formats.E2M1,
+        quant_tile_shape=(1, 16),
+        pow_2_scales=False,
+        with_rht=True,
+    )
 
 
 def cast_to_fp4x2(x):
diff --git 
a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 63adcce6b2..b40d656ab2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1989,14 +1989,14 @@ def get_quantizer_roles( name = self.name or "" if fwd: base = [ - QuantizerRole(module_type="layernorm_mlp", tensor_type="input", name=name), - QuantizerRole(module_type="layernorm_mlp", tensor_type="weight", name=name), - QuantizerRole(module_type="layernorm_mlp", tensor_type="output", name=name), + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + QuantizerRole(module_type="linear", tensor_type="output", name=name), ] else: base = [ - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_output", name=name), - QuantizerRole(module_type="layernorm_mlp", tensor_type="grad_input", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), ] return [base[i % len(base)] for i in range(num_quantizers)] diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 806054abba..0c0cec953f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -58,6 +58,7 @@ class QuantizerRole: module_type : str Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. Empty string when not provided. + See `GEMM_MODULE_TYPES` for the set of GEMM-based module types. tensor_type : str What tensor is being quantized, in the module's own vocabulary. GEMM modules: `"input"`, `"weight"`, `"grad_output"`, etc. @@ -68,12 +69,23 @@ class QuantizerRole: framework), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. Empty string when not provided. 
+ + Class attributes + ---------------- + GEMM_MODULE_TYPES : frozenset of str + Module types that represent GEMM-based operations. """ + GEMM_MODULE_TYPES = frozenset({"linear", "grouped_linear"}) + module_type: str = "" tensor_type: str = "" name: str = "" + def is_gemm(self) -> bool: + """Whether this role belongs to a GEMM-based module.""" + return self.module_type in self.GEMM_MODULE_TYPES + def __str__(self) -> str: parts = [] if self.module_type: From a86fdad67c86b427031ebe1c2811211424d3a69a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 16:27:28 +0000 Subject: [PATCH 14/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/recipe/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 2860fcf7f4..1db8110ae3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -510,7 +510,7 @@ class CustomRecipe(Recipe): - `module_type` (str): module type (empty string when not set), e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. - - `tensor_type` (str): what tensor is being quantized (empty + - `tensor_type` (str): what tensor is being quantized (empty string when not set), e.g. `"input"`, `"weight"`, `"grad_output"`. - `name` (str): caller-provided module instance name (empty string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. 
From d323f66ae2ad4e0af93238bb5931e6f46ca2414b Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 24 Feb 2026 15:05:16 +0100 Subject: [PATCH 15/36] Resolve naming: make every module/op distinguishable via name Signed-off-by: Evgeny --- .../dot_product_attention.py | 5 +++- .../pytorch/attention/multi_head_attention.py | 1 + .../pytorch/module/layernorm_mlp.py | 25 +++++++++++++------ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..875ba21fab 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -284,6 +284,8 @@ class DotProductAttention(TransformerEngineBaseModule): `_). :math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``, and :math:`\text{max_logit}` is of shape ``[h]``. + name : Optional[str], default = None + module instance name. 
Parallelism parameters ---------------------- @@ -343,8 +345,9 @@ def __init__( softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name=name) self.logger = logging.getLogger("DotProductAttention") self.logger.setLevel(attn_log._log_level) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..328abc3cf7 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -461,6 +461,7 @@ def __init__( layer_number=self.layer_number, attention_type=self.attention_type, softmax_type=self.softmax_type, + name=name + ".core_attention" if name is not None else None, ) # Linear diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b40d656ab2..11a40adbf3 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1985,18 +1985,29 @@ def get_quantizer_roles( fwd: bool, num_quantizers: int, ) -> Optional[List[QuantizerRole]]: - """QuantizerRole list for quantizers used by ``LayerNormMLP``.""" - name = self.name or "" + """QuantizerRole list for quantizers used by ``LayerNormMLP``. + + Each internal GEMM (fc1, fc2) gets a distinct name suffix so that + custom-recipe factories can target them individually. 
+ """ + base_name = self.name or "" + fc1_name = f"{base_name}.fc1" if base_name else "fc1" + fc2_name = f"{base_name}.fc2" if base_name else "fc2" if fwd: base = [ - QuantizerRole(module_type="linear", tensor_type="input", name=name), - QuantizerRole(module_type="linear", tensor_type="weight", name=name), - QuantizerRole(module_type="linear", tensor_type="output", name=name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="output", name=fc2_name), ] else: base = [ - QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="grad_input", name=fc2_name), ] return [base[i % len(base)] for i in range(num_quantizers)] From c9eae0f79066d9232bd76711fc629e0c58c476b9 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 24 Feb 2026 18:22:10 +0100 Subject: [PATCH 16/36] Configure output/grad_input roles, defaults to None Signed-off-by: Evgeny --- .../pytorch/attention/multi_head_attention.py | 38 ++++++++++- transformer_engine/pytorch/module/base.py | 68 +++++++++++++++++++ .../pytorch/module/grouped_linear.py | 9 ++- .../pytorch/module/layernorm_linear.py | 14 +++- .../pytorch/module/layernorm_mlp.py | 17 +++-- transformer_engine/pytorch/module/linear.py | 14 +++- 6 files changed, 146 insertions(+), 14 deletions(-) diff --git 
a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 328abc3cf7..09296991d7 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, QuantizerRole from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm @@ -479,6 +479,40 @@ def __init__( **common_gemm_kwargs, ) + def _update_output_quantizer_roles( + self, + qkv_fp8_output: bool, + proj_fp8_grad: bool, + ) -> None: + """Set output / grad-input quantizer roles on QKV and proj linears. + + When the QKV linear's output feeds directly into DPA (``fp8_mha``), + the role is switched from the default linear-consumer assumption to + DPA-consumer roles. Otherwise roles are reset to ``None`` so the + modules fall back to their defaults. 
+ """ + dpa_name = self.core_attention.name or "" + qkv_output_role = ( + QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) + if qkv_fp8_output else None + ) + proj_grad_input_role = ( + QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) + if proj_fp8_grad else None + ) + if self.attention_type == "self": + if self.input_layernorm: + self.layernorm_qkv.output_quantizer_role = qkv_output_role + else: + self.qkv.output_quantizer_role = qkv_output_role + elif self.attention_type == "cross": + if self.input_layernorm: + self.layernorm_query.output_quantizer_role = qkv_output_role + else: + self.query_layer.output_quantizer_role = qkv_output_role + self.key_value.output_quantizer_role = qkv_output_role + self.proj.grad_input_quantizer_role = proj_grad_input_role + def _create_qk_norm_modules( self, qk_norm_type: Optional[str], @@ -796,6 +830,8 @@ def forward( # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling + self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad) + layernorm_output = None if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b6bf44b82e..814a818634 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -632,6 +632,8 @@ def __init__(self, name: Optional[str] = None) -> None: self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None + self._output_quantizer_role: Optional[QuantizerRole] = None + self._grad_input_quantizer_role: Optional[QuantizerRole] = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -652,6 +654,72 @@ def module_setattr(self, name: str, value: Any) -> None: """ super().__setattr__(name, value) + @property + def output_quantizer_role(self) -> 
Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the forward output quantizer. + + When set, overrides the default role used by :meth:`get_quantizer_roles` + for the forward-pass output quantizer slot. Setting this after + quantizers have been created forces their recreation on the next + forward pass. + + See also :attr:`grad_input_quantizer_role` for the backward-pass + counterpart. + """ + return self._output_quantizer_role + + @output_quantizer_role.setter + def output_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._output_quantizer_role: + return + self._output_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + @property + def grad_input_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the grad-input quantizer. + + Backward-pass counterpart of :attr:`output_quantizer_role`. + """ + return self._grad_input_quantizer_role + + @grad_input_quantizer_role.setter + def grad_input_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._grad_input_quantizer_role: + return + self._grad_input_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + def _warn_missing_output_quantizer_role( + self, + fp8_output: bool, + fp8_grad: bool, + ) -> None: + """Warn when quantized output is requested but no consumer role is set. + + Only relevant for ``CustomRecipe`` where the ``qfactory`` dispatches + on roles. Built-in recipes ignore role metadata. + """ + recipe = FP8GlobalStateManager.get_fp8_recipe() + if not recipe.custom(): + return + if fp8_output and self._output_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_output=True but " + "output_quantizer_role is not set. 
The CustomRecipe qfactory " + "will receive None for the output quantizer role.", + stacklevel=3, + ) + if fp8_grad and self._grad_input_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_grad=True but " + "grad_input_quantizer_role is not set. The CustomRecipe " + "qfactory will receive None for the grad-input quantizer role.", + stacklevel=3, + ) + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d41221554e..a9668f9972 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -751,19 +751,22 @@ def get_quantizer_roles( ) -> Optional[List[QuantizerRole]]: """QuantizerRole list for quantizers used by ``GroupedLinear``. - For grouped GEMMs we repeat the same pattern for each GEMM in order. + For grouped GEMMs we repeat the same pattern for each GEMM in + order. The output (fwd) and grad-input (bwd) slots default to + ``None`` (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
""" name = self.name or "" if fwd: base = [ QuantizerRole(module_type="grouped_linear", tensor_type="input", name=name), QuantizerRole(module_type="grouped_linear", tensor_type="weight", name=name), - QuantizerRole(module_type="grouped_linear", tensor_type="output", name=name), + self._output_quantizer_role, ] else: base = [ QuantizerRole(module_type="grouped_linear", tensor_type="grad_output", name=name), - QuantizerRole(module_type="grouped_linear", tensor_type="grad_input", name=name), + self._grad_input_quantizer_role, ] return [base[i % len(base)] for i in range(num_quantizers)] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 22fe52310b..a30efd0ccc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1418,18 +1418,23 @@ def get_quantizer_roles( fwd: bool, num_quantizers: int, ) -> Optional[List[QuantizerRole]]: - """QuantizerRole list for quantizers used by ``LayerNormLinear``.""" + """QuantizerRole list for quantizers used by ``LayerNormLinear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
+ """ name = self.name or "" if fwd: base = [ QuantizerRole(module_type="linear", tensor_type="input", name=name), QuantizerRole(module_type="linear", tensor_type="weight", name=name), - QuantizerRole(module_type="linear", tensor_type="output", name=name), + self._output_quantizer_role, ] else: base = [ QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), + self._grad_input_quantizer_role, ] return [base[i % len(base)] for i in range(num_quantizers)] @@ -1631,6 +1636,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 11a40adbf3..81faa464c5 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1989,6 +1989,12 @@ def get_quantizer_roles( Each internal GEMM (fc1, fc2) gets a distinct name suffix so that custom-recipe factories can target them individually. + + The module's final output (fc2 fwd) and final grad (fc1 bwd) + slots default to ``None`` (unknown consumer). Set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to provide consumer identity. Internal boundaries use fixed + roles with known consumer identity. 
""" base_name = self.name or "" fc1_name = f"{base_name}.fc1" if base_name else "fc1" @@ -1997,17 +2003,17 @@ def get_quantizer_roles( base = [ QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name), QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name), - QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name), - QuantizerRole(module_type="linear", tensor_type="output", name=fc2_name), + self._output_quantizer_role, ] else: base = [ QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=fc1_name), + self._grad_input_quantizer_role, QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), ] return [base[i % len(base)] for i in range(num_quantizers)] @@ -2218,6 +2224,9 @@ def forward( return out def _get_quantizers(self, fp8_output, is_grad_enabled): + if self.fp8: + self._warn_missing_output_quantizer_role(fp8_output, False) + ( fc1_input_quantizer, fc1_output_quantizer, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b44172eb20..27ed63a0fe 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1314,18 +1314,23 @@ def get_quantizer_roles( fwd: bool, num_quantizers: int, ) -> Optional[List[QuantizerRole]]: - """QuantizerRole list for quantizers used by ``Linear``.""" + """QuantizerRole list for quantizers used by ``Linear``. 
+ + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ name = self.name or "" if fwd: base = [ QuantizerRole(module_type="linear", tensor_type="input", name=name), QuantizerRole(module_type="linear", tensor_type="weight", name=name), - QuantizerRole(module_type="linear", tensor_type="output", name=name), + self._output_quantizer_role, ] else: base = [ QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), - QuantizerRole(module_type="linear", tensor_type="grad_input", name=name), + self._grad_input_quantizer_role, ] return [base[i % len(base)] for i in range(num_quantizers)] @@ -1499,6 +1504,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None From ea3c135159aeae765cf297523719d45f011222fb Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 24 Feb 2026 18:47:40 +0100 Subject: [PATCH 17/36] Remove is_gemm() Signed-off-by: Evgeny --- .../custom_recipes/quantization_current_scaling.py | 6 ++---- .../pytorch/custom_recipes/quantization_nvfp4.py | 7 ++++++- transformer_engine/pytorch/quantization.py | 14 +------------- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 9febf06b8e..0034b739cb 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -28,10 +28,8 @@ def current_scaling_ref_quantizer_factory(role): with autocast(recipe=custom_recipe): output = model(input) """ - if role.tensor_type in ("grad_output", 
"grad_input"): - dtype = torch.float8_e5m2 - else: - dtype = torch.float8_e4m3fn + is_backward = role is not None and role.tensor_type == "grad_output" + dtype = torch.float8_e5m2 if is_backward else torch.float8_e4m3fn return CurrentScalingQuantizerRef( dtype=dtype, rowwise=True, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 25b7159c27..ec191d7b9d 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -26,7 +26,12 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): with autocast(recipe=custom_recipe): output = model(input) """ - if role.is_gemm() and role.tensor_type == "weight": + is_weight_tensor_in_gemm = ( + role is not None and + role.module_type in ("linear", "grouped_linear") and + role.tensor_type == "weight" + ) + if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c0cec953f..aeb36edc50 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -58,10 +58,9 @@ class QuantizerRole: module_type : str Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. Empty string when not provided. - See `GEMM_MODULE_TYPES` for the set of GEMM-based module types. tensor_type : str What tensor is being quantized, in the module's own vocabulary. - GEMM modules: `"input"`, `"weight"`, `"grad_output"`, etc. + Linear modules: `"input"`, `"weight"`, `"grad_output"`, etc. DPA: `"qkv"`, `"s"`, etc. Empty string when not provided. name : str @@ -69,23 +68,12 @@ class QuantizerRole: framework), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. Empty string when not provided. 
- - Class attributes - ---------------- - GEMM_MODULE_TYPES : frozenset of str - Module types that represent GEMM-based operations. """ - GEMM_MODULE_TYPES = frozenset({"linear", "grouped_linear"}) - module_type: str = "" tensor_type: str = "" name: str = "" - def is_gemm(self) -> bool: - """Whether this role belongs to a GEMM-based module.""" - return self.module_type in self.GEMM_MODULE_TYPES - def __str__(self) -> str: parts = [] if self.module_type: From aaf980fdf3ed75c50baf3c9959600a0700d99636 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 17:48:35 +0000 Subject: [PATCH 18/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/multi_head_attention.py | 6 ++++-- .../pytorch/custom_recipes/quantization_nvfp4.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 09296991d7..78a1c4dde6 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -494,11 +494,13 @@ def _update_output_quantizer_roles( dpa_name = self.core_attention.name or "" qkv_output_role = ( QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) - if qkv_fp8_output else None + if qkv_fp8_output + else None ) proj_grad_input_role = ( QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) - if proj_fp8_grad else None + if proj_fp8_grad + else None ) if self.attention_type == "self": if self.input_layernorm: diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index ec191d7b9d..0b29977fb4 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ 
b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -27,11 +27,11 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): output = model(input) """ is_weight_tensor_in_gemm = ( - role is not None and - role.module_type in ("linear", "grouped_linear") and - role.tensor_type == "weight" + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" ) - if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules + if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), From aad3512d2c0a5a297b89c78f37ad4f1344f67637 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 25 Feb 2026 14:30:04 +0000 Subject: [PATCH 19/36] Enable base recipes via CustomRecipe and quantization factories Signed-off-by: Evgeny --- .../pytorch/distributed/run_numerics_exact.py | 8 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- .../nvfp4/test_nvfp4_group_quantize.py | 2 +- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 8 +- .../nvfp4/test_nvfp4_quantize_exact.py | 2 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 2 +- tests/pytorch/test_custom_recipe.py | 119 ++++++++++++- .../test_float8_current_scaling_exact.py | 2 +- .../quantization_recipes_base.py | 160 ++++++++++++++++++ ...py => quantization_ref_current_scaling.py} | 0 ...ion_nvfp4.py => quantization_ref_nvfp4.py} | 0 11 files changed, 291 insertions(+), 14 deletions(-) create mode 100644 transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py rename transformer_engine/pytorch/custom_recipes/{quantization_current_scaling.py => quantization_ref_current_scaling.py} (100%) rename transformer_engine/pytorch/custom_recipes/{quantization_nvfp4.py => quantization_ref_nvfp4.py} (100%) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 65af23c490..7ddbb18077 100644 --- 
a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -22,7 +22,7 @@ ) from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors @@ -61,14 +61,14 @@ def get_nvfp4_quantizer_factory(): def factory(role): if role.tensor_type == "input": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=True, ) elif role.tensor_type == "weight": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, @@ -77,7 +77,7 @@ def factory(role): elif role.tensor_type == "output": return None elif role.tensor_type == "grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..8afb103056 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,7 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils diff --git 
a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 01a4a01205..019c5bd566 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -13,7 +13,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index 0977d2a9d9..d727433a28 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -6,7 +6,7 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils @@ -81,14 +81,14 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo def factory(role): if role.tensor_type == "input": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, with_rht=with_rht, ) elif role.tensor_type == "weight": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, @@ -97,7 +97,7 @@ def factory(role): elif role.tensor_type == 
"output": return None elif role.tensor_type == "grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16), pow_2_scales=False, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..99f0f5cdd6 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -7,7 +7,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f54..a9178b25d8 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -12,7 +12,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 36e1cc2744..036442575b 100644 --- a/tests/pytorch/test_custom_recipe.py +++ 
b/tests/pytorch/test_custom_recipe.py @@ -18,7 +18,13 @@ ) from transformer_engine.pytorch.quantization import QuantizerRole import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( +from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + current_scaling_quantizer_factory, + mxfp8_quantizer_factory, + float8_block_scaling_quantizer_factory, + nvfp4_quantizer_factory, +) +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import ( nvfp4_ref_rht_2d_quantizer_factory, ) @@ -333,3 +339,114 @@ def factory(): # Mutating one should not affect the other q1.scale.fill_(123.0) assert not torch.equal(q1.scale, q2.scale) + + +def _run_linear_fwd_bwd(model, inp, recipe): + """Run forward + backward with a given recipe and return (output, inp.grad, param grads).""" + with autocast(enabled=True, recipe=recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + param_grads = {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None} + return out.clone(), inp.grad.clone(), param_grads + + +def _make_pair(in_features=128, out_features=128, batch=32, seed=42): + """Create a pair of identical Linear models and matching inputs.""" + torch.manual_seed(seed) + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus.load_state_dict(model_ref.state_dict()) + + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_cus = base_inp.clone().detach().requires_grad_(True) + return model_ref, model_cus, inp_ref, inp_cus + + +def _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus): + """Assert exact match of outputs and all gradients.""" + assert torch.allclose(out_ref, out_cus, rtol=0.0, 
atol=0.0), ( + f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" + ) + assert torch.allclose(grad_ref, grad_cus, rtol=0.0, atol=0.0), ( + f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" + ) + for name in pgrads_ref: + assert torch.allclose(pgrads_ref[name], pgrads_cus[name], rtol=0.0, atol=0.0), ( + f"Param grad '{name}' mismatch: max diff = " + f"{(pgrads_ref[name] - pgrads_cus[name]).abs().max()}" + ) + + +def test_factory_matches_current_scaling(): + """current_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8CurrentScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8CurrentScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=current_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_mxfp8(): + """mxfp8_quantizer_factory should produce bit-identical results + to the built-in MXFP8BlockScaling recipe.""" + available, reason = te.is_mxfp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"MXFP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.MXFP8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=mxfp8_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_block_scaling(): + """float8_block_scaling_quantizer_factory should produce bit-identical results + 
to the built-in Float8BlockScaling recipe.""" + available = te.is_fp8_block_scaling_available() + if not torch.cuda.is_available() or not available: + pytest.skip("Float8 block scaling unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=float8_block_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_nvfp4(): + """nvfp4_quantizer_factory should produce bit-identical results + to the built-in NVFP4BlockScaling recipe.""" + available = te.is_nvfp4_available() + if not torch.cuda.is_available() or not available: + pytest.skip("NVFP4 unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.NVFP4BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=nvfp4_quantizer_factory) + ) + + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 99ab9c4984..3b964a5af9 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.custom_recipes.quantization import MMParams -from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import ( +from transformer_engine.pytorch.custom_recipes.quantization_ref_current_scaling import ( CurrentScalingQuantizerRef, ) diff 
--git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py
new file mode 100644
index 0000000000..0823a4c7bf
--- /dev/null
+++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""
+Quantizer factory examples using real silicon quantizers.
+
+Each factory below replicates the behaviour of a built-in TE recipe, but via
+the ``CustomRecipe`` + ``qfactory`` interface. This is useful when you want to
+start from a known-good recipe and then selectively override quantizer settings
+for specific layers / tensor types.
+
+Usage (any factory)::
+
+    from transformer_engine.common.recipe import CustomRecipe
+    from transformer_engine.pytorch.quantization import autocast
+    from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import (
+        nvfp4_quantizer_factory,
+    )
+
+    recipe = CustomRecipe(qfactory=nvfp4_quantizer_factory)
+    with autocast(recipe=recipe):
+        output = model(input)
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import transformer_engine_torch as tex
+
+from transformer_engine.pytorch.quantization import QuantizerRole
+
+
+def current_scaling_quantizer_factory(
+    role: Optional[QuantizerRole],
+) -> "Float8CurrentScalingQuantizer":
+    """Factory that mirrors :class:`Float8CurrentScaling` recipe defaults.
+ + * Forward tensors (input, weight) → E4M3 + * Backward tensors (grad_output) → E5M2 + """ + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + is_backward = role is not None and role.tensor_type == "grad_output" + fp8_dtype = tex.DType.kFloat8E5M2 if is_backward else tex.DType.kFloat8E4M3 + + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + force_pow_2_scales=False, # constrain scale to powers of 2 + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + ) + + +def mxfp8_quantizer_factory( + role: Optional[QuantizerRole], +) -> "MXFP8Quantizer": + """Factory that mirrors :class:`MXFP8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Block size 32, power-of-2 (E8M0) scales + """ + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def float8_block_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8BlockQuantizer": + """Factory that mirrors :class:`Float8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Weights use 2D block scaling, everything else uses 1D + * Power-of-2 scales by default + """ + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + is_weight = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + block_scaling_dim = 2 if is_weight else 1 + + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + force_pow_2_scales=True, + block_scaling_dim=block_scaling_dim, # 1 = 1D (1×128), 2 = 2D (128×128) + ) + + +def nvfp4_quantizer_factory( + role: Optional[QuantizerRole], +) -> "NVFP4Quantizer": + """Factory that mirrors :class:`NVFP4BlockScaling` recipe defaults. 
+ + * All tensors quantized to E2M1 (FP4) + * Weights: 2D quantization (16x16 blocks), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + + Quantizer knobs: + fp4_dtype - E2M1 (only supported format) + with_rht - randomized Hadamard transform (smooths outliers) + with_post_rht_amax - recompute amax after RHT (should match with_rht) + with_2d_quantization - 16x16 2D blocks (vs 1x16 1D) + stochastic_rounding - probabilistic rounding to reduce quant bias + with_random_sign_mask - random sign flip in the Hadamard matrix + """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type in ("linear", "grouped_linear") + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + # For input and unknown roles + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py similarity index 100% rename from transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py rename to 
transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py similarity index 100% rename from transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py From 8d7c91f89ff4dbd2ed3330644c5972a83454919b Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 25 Feb 2026 16:23:46 +0000 Subject: [PATCH 20/36] Add factory example - NVFP4 for Linear, MXFP8 for GroupedLinear Signed-off-by: Evgeny --- .../quantization_factory_examples.py | 106 ++++++++++++++++++ transformer_engine/pytorch/quantization.py | 4 + 2 files changed, 110 insertions(+) create mode 100644 transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py new file mode 100644 index 0000000000..5c563da8a9 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples. + +Demonstrates how to use the ``CustomRecipe`` + ``qfactory`` interface to apply +*different* quantization recipes to different module/tensor types/instances within the same model. 
+ +Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_grouped_linear_factory, + ) + + recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_grouped_linear_factory) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def nvfp4_linear_mxfp8_grouped_linear_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, MXFP8 for ``GroupedLinear``. + + Dispatch logic: + * ``role.module_type == "grouped_linear"`` -> MXFP8 (E4M3, block-32) + * everything else (``"linear"`` or unknown) -> NVFP4 (E2M1) + + NVFP4 settings follow the built-in ``NVFP4BlockScaling`` defaults: + * Weights: 2D quantization (16x16), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + """ + is_grouped_linear = role is not None and role.module_type == "grouped_linear" + + if is_grouped_linear: + return _make_mxfp8_quantizer() + + return _make_nvfp4_quantizer(role) + + +def _make_mxfp8_quantizer(): + """Return an MXFP8 quantizer with default settings (E4M3, block-32, E8M0 scales).""" + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): + """Return an NVFP4 quantizer configured per tensor role. + + Mirrors :class:`NVFP4BlockScaling` recipe defaults. 
+ """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type == "linear" + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index aeb36edc50..ff3c0b9a1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -53,6 +53,10 @@ class QuantizerRole: TE modules populate all fields they know about. User factories inspect only the fields they care about. + .. warning:: + **EXPERIMENTAL**: QuantizerRole is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. 
+ Fields ------ module_type : str From 736cd7220ec52469d83160fabaa8c2d974cb5a52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Feb 2026 16:24:55 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 036442575b..5ac72ccf56 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -366,12 +366,12 @@ def _make_pair(in_features=128, out_features=128, batch=32, seed=42): def _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus): """Assert exact match of outputs and all gradients.""" - assert torch.allclose(out_ref, out_cus, rtol=0.0, atol=0.0), ( - f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" - ) - assert torch.allclose(grad_ref, grad_cus, rtol=0.0, atol=0.0), ( - f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" - ) + assert torch.allclose( + out_ref, out_cus, rtol=0.0, atol=0.0 + ), f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" + assert torch.allclose( + grad_ref, grad_cus, rtol=0.0, atol=0.0 + ), f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" for name in pgrads_ref: assert torch.allclose(pgrads_ref[name], pgrads_cus[name], rtol=0.0, atol=0.0), ( f"Param grad '{name}' mismatch: max diff = " From b6bfdf827a7f6e19a4ff94319bc59126941dd10f Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 25 Feb 2026 16:42:40 +0000 Subject: [PATCH 22/36] Fix custom recipe test Signed-off-by: Evgeny --- tests/pytorch/test_custom_recipe.py | 56 +++++++++++++++++------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py 
index 5ac72ccf56..d4ac8f99a7 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -97,9 +97,9 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if role.tensor_type in ("input", "weight", "output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role.tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -134,9 +134,9 @@ def test_custom_recipe_grouped_linear_sanity(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role.tensor_type in ("input", "weight", "output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role.tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -196,9 +196,9 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if role.tensor_type in ("input", "weight", "output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role.tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -206,7 +206,12 @@ def quantizer_factory(role): with autocast(enabled=True, recipe=custom_recipe): out_custom = model_custom(inp_custom) - # Assert dtypes for custom 
quantizers match reference mapping + # Assert dtypes for custom quantizers match reference mapping. + # The output (fwd) and grad_input (bwd) slots receive role=None + # (unknown consumer) and get E4M3 from our factory. The reference + # recipe uses E4M3 for fwd output and E5M2 for bwd grad_input, + # but these quantizers are typically unused so the mismatch doesn't + # affect GEMM results. cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] @@ -216,7 +221,7 @@ def quantizer_factory(role): assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E4M3 # role=None fallback loss_custom = (out_custom.float() * scale.view(1, -1)).sum() loss_custom.backward() @@ -253,9 +258,9 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role.tensor_type in ("input", "weight", "output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role.tensor_type in ("grad_output", "grad_input"): + if role.tensor_type in ("grad_output"): return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -283,41 +288,46 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): op = Linear(in_features, out_features, params_dtype=torch.bfloat16) inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) - # Counters per tensor_type + # Counters per tensor_type. 
The output (fwd) and grad_input (bwd) + # slots have role=None by default (unknown consumer), so we count + # those separately. counts = { "input": 0, "weight": 0, - "output": 0, "grad_output": 0, - "grad_input": 0, + None: 0, } def quantizer_factory(role): + if role is None: + counts[None] += 1 + return Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, device=torch.device("cuda") + ) assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" assert role.module_type == "linear" if role.tensor_type in counts: counts[role.tensor_type] += 1 - if role.tensor_type in ("input", "weight", "output"): - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if role.tensor_type in ("grad_output", "grad_input"): - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + if role.tensor_type == "grad_output": + return Float8CurrentScalingQuantizer( + tex.DType.kFloat8E5M2, device=torch.device("cuda") + ) + return Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, device=torch.device("cuda") + ) custom = recipe.CustomRecipe(qfactory=quantizer_factory) - # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), - # and backward to build 2 quantizers (cycled from 1 factory). 
with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() - # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input + # Forward: input, weight, output(None); backward: grad_output, grad_input(None) assert counts["input"] == 1 assert counts["weight"] == 1 - assert counts["output"] == 1 assert counts["grad_output"] == 1 - assert counts["grad_input"] == 1 + assert counts[None] == 2, f"Expected 2 None roles (output + grad_input), got {counts[None]}" def test_factories_return_distinct_instances_and_buffers(): From 41656ab738dab84f12f31156f7bee576b4cb3108 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Feb 2026 16:44:57 +0000 Subject: [PATCH 23/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index d4ac8f99a7..20b76d67db 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -301,20 +301,14 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): def quantizer_factory(role): if role is None: counts[None] += 1 - return Float8CurrentScalingQuantizer( - tex.DType.kFloat8E4M3, device=torch.device("cuda") - ) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" assert role.module_type == "linear" if role.tensor_type in counts: counts[role.tensor_type] += 1 if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer( - tex.DType.kFloat8E5M2, device=torch.device("cuda") - ) - return Float8CurrentScalingQuantizer( - tex.DType.kFloat8E4M3, device=torch.device("cuda") - ) + return 
Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) custom = recipe.CustomRecipe(qfactory=quantizer_factory) From cca370c1d3efa9dc69cd56e3965dee23786e3031 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 26 Feb 2026 11:17:38 +0000 Subject: [PATCH 24/36] Test fine-grained quantization targets Signed-off-by: Evgeny --- tests/pytorch/test_custom_recipe.py | 199 ++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 20b76d67db..3855f317c6 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -454,3 +454,202 @@ def test_factory_matches_nvfp4(): ) _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_custom_recipe_quantization_targets(): + """Validate fine-grained per-module quantization targeting via QuantizerRole. + + Four transformer layers, each assembled at a different abstraction level. 
+ The default recipe is NVFP4; specific modules are overridden: + + Layer 0 - ``TransformerLayer`` (name="tl0") -> all MXFP8 + Layer 1 - ``TransformerLayer`` (name="tl1") -> NVFP4 (default), + except fc2 overridden to MXFP8 + Layer 2 - ``MultiheadAttention`` + ``LayerNormMLP`` + (name prefix "tl2") -> NVFP4 (default), + except qkv and fc1 overridden to Float8 block-scaling + Layer 3 - Individual blocks (name prefix "tl3") -> NVFP4 (default), + except proj overridden to Float8 current-scaling + + The test validates that: + * The factory receives QuantizerRole objects with correct names + * Different quantizer types are dispatched per module + * Forward + backward complete successfully through all four layers + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_mxfp8_available(): + pytest.skip("MXFP8 unsupported on this device") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + if not te.is_fp8_block_scaling_available(): + pytest.skip("Float8 block scaling unsupported on this device") + + torch.manual_seed(42) + + H = 64 # hidden_size + FFN = 64 # ffn_hidden_size + NH = 4 # num_heads + KV = H // NH # kv_channels + B = 4 # batch + S = 8 # seq_len + common = dict(params_dtype=torch.bfloat16, bias=False) + + # Layer 0: TransformerLayer -> MXFP8 + tl0 = te.TransformerLayer( + H, FFN, NH, hidden_dropout=0.0, attention_dropout=0.0, name="tl0", **common, + ).cuda() + + # Layer 1: TransformerLayer -> NVFP4 default, fc2 overridden to MXFP8 + tl1 = te.TransformerLayer( + H, FFN, NH, hidden_dropout=0.0, attention_dropout=0.0, name="tl1", **common, + ).cuda() + + # Layer 2: MHA + LayerNormMLP -> NVFP4 default, qkv and fc1 to block-scaling + tl2_mha = te.MultiheadAttention( + H, NH, KV, attention_dropout=0.0, input_layernorm=True, return_bias=True, + name="tl2.self_attention", **common, + ).cuda() + 
tl2_mlp = LayerNormMLP(H, FFN, name="tl2.layernorm_mlp", **common).cuda() + + # Layer 3: Individual blocks with DPA -> NVFP4 default, proj to current-scaling + tl3_qkv = LayerNormLinear(H, 3 * H, name="tl3.qkv", **common).cuda() + tl3_dpa = te.DotProductAttention(NH, KV, attention_dropout=0.0, name="tl3.core_attention") + tl3_proj = Linear(H, H, name="tl3.proj", **common).cuda() + tl3_fc1 = LayerNormLinear(H, FFN, name="tl3.fc1", **common).cuda() + tl3_fc2 = Linear(FFN, H, name="tl3.fc2", **common).cuda() + + # ------------------------------------------------------------------ + # Recording + dispatching factory + # ------------------------------------------------------------------ + recorded_roles = [] + + def targeting_factory(role): + recorded_roles.append(role) + + if role is None: + return nvfp4_quantizer_factory(role) + + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + + # Layer 0 (tl0.*): all MXFP8 + if role.name.startswith("tl0"): + return mxfp8_quantizer_factory(role) + + # Layer 1 (tl1.*): NVFP4 default, but fc2 overridden to MXFP8 + if role.name == "tl1.layernorm_mlp.fc2": + return mxfp8_quantizer_factory(role) + + # Layer 2: block scaling for qkv and fc1, rest falls through to default + if role.name == "tl2.self_attention.layernorm_linear_qkv": + return float8_block_scaling_quantizer_factory(role) + if role.name == "tl2.layernorm_mlp.fc1": + return float8_block_scaling_quantizer_factory(role) + + # Layer 3: current-scaling for proj, rest falls through to default + if role.name == "tl3.proj": + return current_scaling_quantizer_factory(role) + + # Default: NVFP4 + return nvfp4_quantizer_factory(role) + + custom_recipe = recipe.CustomRecipe(qfactory=targeting_factory) + + # ------------------------------------------------------------------ + # Forward + backward + # ------------------------------------------------------------------ + inp = torch.randn(S, B, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + 
+ with autocast(enabled=True, recipe=custom_recipe): + # Layer 0 & 1: TransformerLayer + h = tl1(tl0(inp)) + + # Layer 2: MHA + residual + LayerNormMLP + residual + attn_out, _ = tl2_mha(h) + h = h + attn_out + h = h + tl2_mlp(h) + + # Layer 3: individual blocks with DPA + residual = h + qkv = tl3_qkv(h).view(S, B, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn = tl3_dpa(q, k, v).view(S, B, H) + h = residual + tl3_proj(attn) + residual = h + h = residual + tl3_fc2(torch.nn.functional.gelu(tl3_fc1(h))) + + loss = h.float().sum() + loss.backward() + + # ------------------------------------------------------------------ + # Assertions + # ------------------------------------------------------------------ + + assert inp.grad is not None, "Input gradient is None" + + # -- Name propagation check -- + # The factory dispatches on role.name, so if a TE module fails to propagate + # names (e.g. TransformerLayer -> MHA -> LayerNormLinear) the factory would + # silently fall through to the default recipe. The quantizer-type assertions + # below would catch that too, but checking names explicitly gives a clearer + # error message pointing at the broken name rather than a wrong quantizer type. 
+ role_names = {r.name for r in recorded_roles if r is not None} + + def _tl_names(prefix): + """Expected role names for a standard TransformerLayer with given prefix.""" + return { + f"{prefix}.self_attention.layernorm_linear_qkv", + f"{prefix}.self_attention.proj", + f"{prefix}.layernorm_mlp.fc1", + f"{prefix}.layernorm_mlp.fc2", + } + + all_expected = ( + _tl_names("tl0") | _tl_names("tl1") | _tl_names("tl2") + | {"tl3.qkv", "tl3.proj", "tl3.fc1", "tl3.fc2"} + ) + missing = all_expected - role_names + assert not missing, ( + f"Expected module names not seen in QuantizerRole.name: {missing}\n" + f"Recorded names: {sorted(role_names)}" + ) + + for r in recorded_roles: + if r is not None and r.module_type: + assert r.module_type == "linear", ( + f"Unexpected module_type={r.module_type} for role {r}" + ) + + # -- Quantizer-type checks -- + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer + + def _check_q(mod, expected_cls, label=""): + q = mod.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + assert isinstance(q, expected_cls), ( + f"{mod.name}{' (' + label + ')' if label else ''}: " + f"expected {expected_cls.__name__}, got {type(q).__name__}" + ) + + # Layer 0: all MXFP8 + _check_q(tl0.self_attention.layernorm_qkv, MXFP8Quantizer) + _check_q(tl0.self_attention.proj, MXFP8Quantizer) + + # Layer 1: NVFP4 default, fc2 overridden to MXFP8 + _check_q(tl1.self_attention.layernorm_qkv, NVFP4Quantizer, "default") + _check_q(tl1.self_attention.proj, NVFP4Quantizer, "default") + assert any( + r is not None and r.name == "tl1.layernorm_mlp.fc2" and r.tensor_type == "input" + for r in recorded_roles + ), "tl1.layernorm_mlp.fc2 input role not recorded" + + # Layer 2: block-scaling on qkv and fc1, NVFP4 on proj and fc2 + _check_q(tl2_mha.layernorm_qkv, Float8BlockQuantizer) + 
_check_q(tl2_mha.proj, NVFP4Quantizer, "default") + + # Layer 3: current-scaling on proj, NVFP4 on everything else + _check_q(tl3_proj, Float8CurrentScalingQuantizer) + for mod in [tl3_qkv, tl3_fc1, tl3_fc2]: + _check_q(mod, NVFP4Quantizer, "default") From 343f653d9b53a56a48db3b9dc2670b9e81bf1067 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:18:50 +0000 Subject: [PATCH 25/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 46 ++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 3855f317c6..612e254bf9 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -488,28 +488,46 @@ def test_custom_recipe_quantization_targets(): torch.manual_seed(42) - H = 64 # hidden_size - FFN = 64 # ffn_hidden_size - NH = 4 # num_heads - KV = H // NH # kv_channels - B = 4 # batch - S = 8 # seq_len + H = 64 # hidden_size + FFN = 64 # ffn_hidden_size + NH = 4 # num_heads + KV = H // NH # kv_channels + B = 4 # batch + S = 8 # seq_len common = dict(params_dtype=torch.bfloat16, bias=False) # Layer 0: TransformerLayer -> MXFP8 tl0 = te.TransformerLayer( - H, FFN, NH, hidden_dropout=0.0, attention_dropout=0.0, name="tl0", **common, + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl0", + **common, ).cuda() # Layer 1: TransformerLayer -> NVFP4 default, fc2 overridden to MXFP8 tl1 = te.TransformerLayer( - H, FFN, NH, hidden_dropout=0.0, attention_dropout=0.0, name="tl1", **common, + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl1", + **common, ).cuda() # Layer 2: MHA + LayerNormMLP -> NVFP4 default, qkv and fc1 to block-scaling tl2_mha = te.MultiheadAttention( - H, NH, KV, attention_dropout=0.0, 
input_layernorm=True, return_bias=True, - name="tl2.self_attention", **common, + H, + NH, + KV, + attention_dropout=0.0, + input_layernorm=True, + return_bias=True, + name="tl2.self_attention", + **common, ).cuda() tl2_mlp = LayerNormMLP(H, FFN, name="tl2.layernorm_mlp", **common).cuda() @@ -606,7 +624,9 @@ def _tl_names(prefix): } all_expected = ( - _tl_names("tl0") | _tl_names("tl1") | _tl_names("tl2") + _tl_names("tl0") + | _tl_names("tl1") + | _tl_names("tl2") | {"tl3.qkv", "tl3.proj", "tl3.fc1", "tl3.fc2"} ) missing = all_expected - role_names @@ -617,9 +637,7 @@ def _tl_names(prefix): for r in recorded_roles: if r is not None and r.module_type: - assert r.module_type == "linear", ( - f"Unexpected module_type={r.module_type} for role {r}" - ) + assert r.module_type == "linear", f"Unexpected module_type={r.module_type} for role {r}" # -- Quantizer-type checks -- from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer From 6c0381068179d9e62ed089a133763874556ee55a Mon Sep 17 00:00:00 2001 From: Evgeny Date: Mon, 2 Mar 2026 12:58:54 +0000 Subject: [PATCH 26/36] Add quantizer roles for attention (attn is wip) Signed-off-by: Evgeny --- .../dot_product_attention.py | 50 +++++++++++++++++++ .../pytorch/attention/multi_head_attention.py | 30 ++++++++++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 875ba21fab..cb012377b8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -22,6 +22,7 @@ ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( + QuantizerRole, get_fp8_te_dtype, FP8GlobalStateManager, RecipeState, @@ -805,6 +806,55 @@ def set_meta_tensor(self, fwd: 
bool, recipe: Union[Recipe, List[Recipe]]) -> Non for recipe_state in recipe_states: self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``DotProductAttention``. + + Quantizer positions follow the GEMM-slot convention used by the + fused-attention kernels: + + Forward (3 GEMMs x 3 = 9 slots): + GEMM1 -> QKV (at ``GEMM1_OUTPUT``), + GEMM2 -> O (at ``GEMM2_INPUT``), + GEMM3 -> S (at ``GEMM3_OUTPUT``). + + Backward (3 GEMMs x 2 = 6 slots): + GEMM1 -> dQKV (at ``GRAD_OUTPUT1``), + GEMM2 -> dO (at ``GRAD_INPUT2``), + GEMM3 -> dP (at ``GRAD_INPUT3``). + + Unused positions in each GEMM group share the role of the + group's primary tensor. + + The O (fwd) and dQKV (bwd) slots mirror the output / grad-input + pattern from linear modules. Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. 
+ """ + name = self.name or "" + if fwd: + qkv_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name=name) + o_role = self._output_quantizer_role + s_role = QuantizerRole(module_type="dpa", tensor_type="s", name=name) + base = [ + qkv_role, qkv_role, qkv_role, # GEMM1: QKV at GEMM1_OUTPUT + o_role, o_role, o_role, # GEMM2: O at GEMM2_INPUT + s_role, s_role, s_role, # GEMM3: S at GEMM3_OUTPUT + ] + else: + dqkv_role = self._grad_input_quantizer_role + do_role = QuantizerRole(module_type="dpa", tensor_type="do", name=name) + dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) + base = [ + dqkv_role, dqkv_role, # GEMM1: dQKV at GRAD_OUTPUT1 + do_role, do_role, # GEMM2: dO at GRAD_INPUT2 + dp_role, dp_role, # GEMM3: dP at GRAD_INPUT3 + ] + return base[:num_quantizers] + @no_torch_dynamo(recursive=False) def forward( self, diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 78a1c4dde6..2378e10396 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -483,13 +483,19 @@ def _update_output_quantizer_roles( self, qkv_fp8_output: bool, proj_fp8_grad: bool, + dpa_fp8_output: bool, ) -> None: - """Set output / grad-input quantizer roles on QKV and proj linears. + """Set output / grad-input quantizer roles on QKV, proj, and DPA. When the QKV linear's output feeds directly into DPA (``fp8_mha``), the role is switched from the default linear-consumer assumption to DPA-consumer roles. Otherwise roles are reset to ``None`` so the modules fall back to their defaults. + + Symmetrically, when DPA produces FP8 output / gradients, its + ``output_quantizer_role`` (O -> proj) and + ``grad_input_quantizer_role`` (dQKV -> QKV linear) are set to + describe the consuming linear module. 
""" dpa_name = self.core_attention.name or "" qkv_output_role = ( @@ -515,6 +521,26 @@ def _update_output_quantizer_roles( self.key_value.output_quantizer_role = qkv_output_role self.proj.grad_input_quantizer_role = proj_grad_input_role + # DPA boundary roles: O -> proj (fwd), dQKV -> QKV linear (bwd) + proj_name = self.proj.name or "" + self.core_attention.output_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="input", name=proj_name) + if dpa_fp8_output + else None + ) + if self.attention_type == "self": + qkv_linear = self.layernorm_qkv if self.input_layernorm else self.qkv + else: + qkv_linear = ( + self.layernorm_query if self.input_layernorm else self.query_layer + ) + qkv_name = qkv_linear.name or "" + self.core_attention.grad_input_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="grad_output", name=qkv_name) + if dpa_fp8_output + else None + ) + def _create_qk_norm_modules( self, qk_norm_type: Optional[str], @@ -832,7 +858,7 @@ def forward( # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling - self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad) + self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad, dpa_fp8_output) layernorm_output = None if self.attention_type == "self": From 9b0e497bb444b62a9abc495d90e89fbebd165222 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 13:00:12 +0000 Subject: [PATCH 27/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention.py | 21 +++++++++++++------ .../pytorch/attention/multi_head_attention.py | 4 +--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py 
b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index cb012377b8..8808e84482 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -840,18 +840,27 @@ def get_quantizer_roles( o_role = self._output_quantizer_role s_role = QuantizerRole(module_type="dpa", tensor_type="s", name=name) base = [ - qkv_role, qkv_role, qkv_role, # GEMM1: QKV at GEMM1_OUTPUT - o_role, o_role, o_role, # GEMM2: O at GEMM2_INPUT - s_role, s_role, s_role, # GEMM3: S at GEMM3_OUTPUT + qkv_role, + qkv_role, + qkv_role, # GEMM1: QKV at GEMM1_OUTPUT + o_role, + o_role, + o_role, # GEMM2: O at GEMM2_INPUT + s_role, + s_role, + s_role, # GEMM3: S at GEMM3_OUTPUT ] else: dqkv_role = self._grad_input_quantizer_role do_role = QuantizerRole(module_type="dpa", tensor_type="do", name=name) dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) base = [ - dqkv_role, dqkv_role, # GEMM1: dQKV at GRAD_OUTPUT1 - do_role, do_role, # GEMM2: dO at GRAD_INPUT2 - dp_role, dp_role, # GEMM3: dP at GRAD_INPUT3 + dqkv_role, + dqkv_role, # GEMM1: dQKV at GRAD_OUTPUT1 + do_role, + do_role, # GEMM2: dO at GRAD_INPUT2 + dp_role, + dp_role, # GEMM3: dP at GRAD_INPUT3 ] return base[:num_quantizers] diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 2378e10396..14216cb337 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -531,9 +531,7 @@ def _update_output_quantizer_roles( if self.attention_type == "self": qkv_linear = self.layernorm_qkv if self.input_layernorm else self.qkv else: - qkv_linear = ( - self.layernorm_query if self.input_layernorm else self.query_layer - ) + qkv_linear = self.layernorm_query if self.input_layernorm else self.query_layer 
qkv_name = qkv_linear.name or ""
 self.core_attention.grad_input_quantizer_role = (
 QuantizerRole(module_type="linear", tensor_type="grad_output", name=qkv_name)

From 1d630848e9d558dd14903f554fbc90851d04ea3e Mon Sep 17 00:00:00 2001
From: Evgeny
Date: Mon, 2 Mar 2026 14:28:28 +0000
Subject: [PATCH 28/36] Enable stateful recipes in the Custom recipe - Delayed Scaling support

Signed-off-by: Evgeny
---
 tests/pytorch/test_custom_recipe.py | 162 ++++++++++++++
 transformer_engine/common/recipe/__init__.py | 28 ++-
 transformer_engine/pytorch/__init__.py | 2 +
 .../quantization_recipes_base.py | 19 ++
 transformer_engine/pytorch/module/base.py | 18 +-
 transformer_engine/pytorch/quantization.py | 204 ++++++++++++++++--
 6 files changed, 396 insertions(+), 37 deletions(-)

diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py
index 612e254bf9..8cbf7fb0ac 100644
--- a/tests/pytorch/test_custom_recipe.py
+++ b/tests/pytorch/test_custom_recipe.py
@@ -23,6 +23,7 @@
     mxfp8_quantizer_factory,
     float8_block_scaling_quantizer_factory,
     nvfp4_quantizer_factory,
+    delayed_scaling_quantizer_factory,
 )
 from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import (
     nvfp4_ref_rht_2d_quantizer_factory,
@@ -383,6 +384,24 @@ def _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus):
     )


+def test_factory_matches_delayed_scaling():
+    """delayed_scaling_quantizer_factory should produce bit-identical results
+    to the built-in DelayedScaling recipe."""
+    available, reason = te.is_fp8_available(return_reason=True)
+    if not torch.cuda.is_available() or not available:
+        pytest.skip(f"FP8 unsupported: {reason}")
+
+    model_ref, model_cus, inp_ref, inp_cus = _make_pair()
+
+    out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd(
+        model_ref, inp_ref, recipe.DelayedScaling()
+    )
+    out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd(
+        model_cus, inp_cus, recipe.CustomRecipe(qfactory=delayed_scaling_quantizer_factory)
+    )
+    _assert_match(out_ref,
out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + def test_factory_matches_current_scaling(): """current_scaling_quantizer_factory should produce bit-identical results to the built-in Float8CurrentScaling recipe.""" @@ -671,3 +690,146 @@ def _check_q(mod, expected_cls, label=""): _check_q(tl3_proj, Float8CurrentScalingQuantizer) for mod in [tl3_qkv, tl3_fc1, tl3_fc2]: _check_q(mod, NVFP4Quantizer, "default") + + +def test_delayed_scaling_request_wiring(): + """Shared buffers, correct views, Float8Quantizer instances.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID, amax_history_len=16) + + custom_recipe = recipe.CustomRecipe(qfactory=ds_factory) + + # 3 quantizers (input, weight, output) like a Linear fwd + state = CustomRecipeState(custom_recipe, mode="forward", num_quantizers=3) + state.roles = [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + quantizers = state.make_quantizers() + + # All quantizers should be Float8Quantizer + assert len(quantizers) == 3 + for q in quantizers: + assert isinstance(q, Float8Quantizer), f"Expected Float8Quantizer, got {type(q).__name__}" + + # Managed state should exist + assert state._has_delayed_scaling + assert state.scale is not None + assert state.amax_history is not None + + # Shared buffers: scale shape = (3,), amax_history shape = (16, 3) + assert state.scale.shape == (3,) + assert state.amax_history.shape == (16, 3) + + # Each quantizer's scale should 
be a view into the shared buffer + for i, q in enumerate(quantizers): + assert q.scale.data_ptr() == state.scale[i].data_ptr() + + # Each quantizer's amax should be a view into amax_history[0] + for i, q in enumerate(quantizers): + assert q.amax.data_ptr() == state.amax_history[0][i].reshape((1,)).data_ptr() + + # Inner recipe should be a DelayedScaling + inner = state._inner_delayed_scaling_recipe + assert isinstance(inner, recipe.DelayedScaling) + assert inner.amax_history_len == 16 + assert inner.fp8_format == Format.HYBRID + + +def test_custom_recipe_mixed_ds_and_stateless(): + """Mix DelayedScalingRequest + stateless quantizers in same CustomRecipeState.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def mixed_factory(role): + # Only weight gets delayed scaling, rest get current scaling + if role is not None and role.tensor_type == "weight": + return DelayedScalingRequest(fp8_format=Format.HYBRID) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=mixed_factory) + + # 3 quantizers: input(current), weight(DS), output(current) + state = CustomRecipeState(custom_recipe, mode="forward", num_quantizers=3) + state.roles = [ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ] + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + + # Slot 0 (input): current scaling + assert isinstance(quantizers[0], Float8CurrentScalingQuantizer) + # Slot 1 (weight): delayed scaling + assert 
isinstance(quantizers[1], Float8Quantizer) + # Slot 2 (output): current scaling + assert isinstance(quantizers[2], Float8CurrentScalingQuantizer) + + # Only 1 DS request => shared buffers have size 1 + assert state._has_delayed_scaling + assert state.scale.shape == (1,) + assert state.amax_history.shape == (1024, 1) + + +def test_custom_recipe_ds_multi_step(): + """amax_history updates across multiple forward steps.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + in_features = 128 + out_features = 128 + batch = 32 + num_steps = 3 + + torch.manual_seed(99) + model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + custom = recipe.CustomRecipe(qfactory=ds_factory) + + amax_snapshots = [] + for step in range(num_steps): + inp = torch.randn( + batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + with autocast(enabled=True, recipe=custom): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Capture amax_history snapshot + fwd_state = model.fp8_meta["scaling_fwd"] + amax_snapshots.append(fwd_state.amax_history.clone()) + + # After 3 steps, amax_history should have been updated at least once + # The first row (amax_history[0]) should differ from the initial zeros + # after the first step + assert not torch.all(amax_snapshots[0] == 0), ( + "amax_history should be updated after first step" + ) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1db8110ae3..01a741378e 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -499,23 +499,31 @@ class 
CustomRecipe(Recipe): Parameters ---------- qfactory : Callable - Factory callable that returns a quantizer instance for a given `QuantizerRole`. + Factory callable that returns a quantizer instance *or* a + ``QuantizerRequest`` subclass for a given ``QuantizerRole``. The callable is invoked as:: qfactory( role: QuantizerRole, - ) -> Optional[Quantizer] + ) -> Union[Quantizer, QuantizerRequest] - `QuantizerRole` is a frozen dataclass with the following fields: + ``QuantizerRole`` is a frozen dataclass with the following fields: - - `module_type` (str): module type (empty string when not set), e.g. - `"linear"`, `"grouped_linear"`, `"dpa"`. - - `tensor_type` (str): what tensor is being quantized (empty - string when not set), e.g. `"input"`, `"weight"`, `"grad_output"`. - - `name` (str): caller-provided module instance name (empty - string when not set), e.g. `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`. + - ``module_type`` (str): module type (empty string when not set), e.g. + ``"linear"``, ``"grouped_linear"``, ``"dpa"``. + - ``tensor_type`` (str): what tensor is being quantized (empty + string when not set), e.g. ``"input"``, ``"weight"``, ``"grad_output"``. + - ``name`` (str): caller-provided module instance name (empty + string when not set), e.g. ``"qkv"``, ``"proj"``, ``"fc1"``, ``"fc2"``. - See `transformer_engine.pytorch.quantization.QuantizerRole` + For stateful quantizers (delayed scaling), return a + ``DelayedScalingRequest`` dataclass instead of a quantizer. + TE will allocate shared scale/amax_history buffers and create + ``Float8Quantizer`` instances integrated with the existing + delayed-scaling reduction infrastructure. + + See ``transformer_engine.pytorch.quantization.QuantizerRole`` + and ``transformer_engine.pytorch.quantization.DelayedScalingRequest`` for full documentation. 
""" diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 4880959546..44907dd658 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -49,6 +49,8 @@ from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe from transformer_engine.pytorch.quantization import QuantizerRole +from transformer_engine.pytorch.quantization import QuantizerRequest +from transformer_engine.pytorch.quantization import DelayedScalingRequest from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import is_bf16_available diff --git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py index 0823a4c7bf..c79f0a4b9b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py @@ -33,6 +33,25 @@ from transformer_engine.pytorch.quantization import QuantizerRole +def delayed_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "DelayedScalingRequest": + """Factory that mirrors :class:`DelayedScaling` recipe defaults. + + Returns a :class:`DelayedScalingRequest` for every slot. TE allocates + shared scale/amax_history buffers and wires them into the existing + delayed-scaling reduction path. 
+ + * HYBRID format: E4M3 forward, E5M2 backward + * amax_history_len = 1024 + * reduce_amax = True + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + def current_scaling_quantizer_factory( role: Optional[QuantizerRole], ) -> "Float8CurrentScalingQuantizer": diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1c5176834d..05793e95c8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -27,9 +27,11 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState, + CustomRecipeState, FP8GlobalStateManager, QuantizerRole, RecipeState, + _has_delayed_scaling_state, ) from ..distributed import ( gather_along_first_dim, @@ -795,6 +797,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return + if recipe.custom() and isinstance(recipe_state, CustomRecipeState): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -932,7 +936,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Copy tensors to CPU and store state = {} state["recipe"] = self.fp8_meta["recipe"] - if state["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) @@ -1004,7 +1008,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load tensors - if self.fp8_meta["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) @@ -1131,7 +1135,7 @@ def prepare_forward( # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): - delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = _has_delayed_scaling_state(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: assert inp.is_cuda, "TransformerEngine needs CUDA." 
@@ -1143,10 +1147,12 @@ def prepare_forward( self.init_fp8_metadata(num_gemms=num_gemms) self._check_weight_tensor_recipe_correspondence() - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe: if self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( + assert self.fp8_meta["recipe"].reduce_amax or ( + self.fp8_meta["recipe"].custom() + ), ( "Amax reduction across tensor parallel group is " "necessary when using sequence parallelism with FP8." ) @@ -1168,7 +1174,7 @@ def end_forward(self): Required to be called at the end of the forward function to properly handle DelayedScaling metadata handling and the NVTX ranges. """ - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) nvtx_range_pop() diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index ff3c0b9a1c..16b920036d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -43,6 +43,8 @@ "get_default_recipe", "get_align_size_for_quantization", "QuantizerRole", + "QuantizerRequest", + "DelayedScalingRequest", ] @@ -89,6 +91,56 @@ def __str__(self) -> str: return "|".join(parts) if parts else "QuantizerRole()" +@dataclasses.dataclass(frozen=True) +class QuantizerRequest: + """Base class for stateful quantizer requests. + + Custom recipe factories return ``QuantizerRequest`` subclasses (instead of + quantizer instances) when the quantizer requires TE-managed shared state. + TE detects these requests, allocates the required state, and replaces them + with real quantizer instances. + + .. 
warning:: + **EXPERIMENTAL**: QuantizerRequest is experimental, still under active + development, and the API is subject to change without notice. + """ + + +@dataclasses.dataclass(frozen=True) +class DelayedScalingRequest(QuantizerRequest): + """Request a Float8Quantizer with TE-managed delayed scaling state. + + .. warning:: + **EXPERIMENTAL**: DelayedScalingRequest is experimental, still under active + development, and the API is subject to change without notice. + + All ``DelayedScalingRequest`` instances within the same ``CustomRecipeState`` + must share identical parameter values. + + Parameters + ---------- + fp8_format : Format, default = Format.HYBRID + Controls fwd/bwd dtype (HYBRID = E4M3 fwd, E5M2 bwd). + margin : int, default = 0 + Margin for scaling factor computation. + amax_history_len : int, default = 1024 + Length of the amax history window. + amax_compute_algo : str or Callable, default = "max" + Algorithm for choosing amax from history. + scaling_factor_compute_algo : Callable or None, default = None + Custom scaling factor computation. + reduce_amax : bool, default = True + Whether to all-reduce amax across the distributed group. + """ + + fp8_format: Format = Format.HYBRID + margin: int = 0 + amax_history_len: int = 1024 + amax_compute_algo: Union[str, Callable] = "max" + scaling_factor_compute_algo: Optional[Callable] = None + reduce_amax: bool = True + + @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" @@ -407,7 +459,7 @@ def add_fp8_tensors_to_global_buffer( fp8_meta: Dict[str, Any], ) -> None: """ - Delayed scaling only. + Delayed scaling only (built-in or custom recipe with DS requests). The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is @@ -422,8 +474,8 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. 
""" - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + # noop unless delayed scaling state is present + if not _has_delayed_scaling_state(fp8_meta): return # Every module must call this function exactly once since @@ -440,18 +492,32 @@ def add_fp8_tensors_to_global_buffer( # Handles non-parameter FP8 modules, e.g. DPA. continue - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) + state = fp8_meta[fp8_meta_tensor_key] + + # Determine recipe + buffers: built-in DS or custom with DS requests + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + inner_recipe = state._inner_delayed_scaling_recipe + amax_hist = state.amax_history + scale_tensor = state.scale + key = cls.get_key_in_buffer(forward, inner_recipe, fp8_meta["fp8_group"]) + # Register inner recipe in autocast_arguments for reduction + autocast_key = cls.get_unique_autocast_key(inner_recipe, fp8_meta["fp8_group"]) + cls.autocast_arguments[autocast_key] = (inner_recipe, fp8_meta["fp8_group"]) + else: + amax_hist = state.amax_history + scale_tensor = state.scale + key = cls.get_key_in_buffer( + forward, fp8_meta["recipe"], fp8_meta["fp8_group"] + ) if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_amax_buffer[key] = [amax_hist[0]] + cls.global_amax_history_buffer[key] = [amax_hist] + cls.global_scale_buffer[key] = [scale_tensor] else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_amax_buffer[key].append(amax_hist[0]) + 
cls.global_amax_history_buffer[key].append(amax_hist) + cls.global_scale_buffer[key].append(scale_tensor) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @@ -661,7 +727,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -687,7 +753,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return # Store updated amaxes and scales from phase 1 post forward. @@ -706,7 +772,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -1404,13 +1470,92 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: raise RuntimeError(f"Unexpected recipe mode ({self.mode})") +def _handle_delayed_scaling_requests( + raw: list, + device: torch.device, + mode: str, +) -> Optional["DelayedScalingRecipeState"]: + """Detect DelayedScalingRequest items, allocate shared state, replace with real quantizers. + + All DS requests in the same RecipeState must share identical parameters. + + Returns a ``DelayedScalingRecipeState`` owning the shared buffers, or + ``None`` when no DS requests are present. 
+ """ + ds_items = [(i, r) for i, r in enumerate(raw) if isinstance(r, DelayedScalingRequest)] + if not ds_items: + return None + + r0 = ds_items[0][1] + + # Validate all DS requests share same params + for idx, req in ds_items[1:]: + for field_name in ( + "fp8_format", + "margin", + "amax_history_len", + "amax_compute_algo", + "scaling_factor_compute_algo", + "reduce_amax", + ): + v0 = getattr(r0, field_name) + vi = getattr(req, field_name) + if v0 != vi: + raise ValueError( + f"All DelayedScalingRequests in one CustomRecipeState must match. " + f"Slot 0 has {field_name}={v0!r}, slot {idx} has {vi!r}." + ) + + # Build a real DelayedScalingRecipeState to own the shared buffers. + inner_recipe = DelayedScaling( + fp8_format=r0.fp8_format, + margin=r0.margin, + amax_history_len=r0.amax_history_len, + amax_compute_algo=r0.amax_compute_algo, + scaling_factor_compute_algo=r0.scaling_factor_compute_algo, + reduce_amax=r0.reduce_amax, + ) + n = len(ds_items) + dsrs = DelayedScalingRecipeState( + inner_recipe, mode=mode, num_quantizers=n, device=device, + ) + + # Splice Float8Quantizer instances (backed by dsrs buffers) into raw list. + quantizers = dsrs.make_quantizers() + for j, (idx, _req) in enumerate(ds_items): + raw[idx] = quantizers[j] + + return dsrs + + +def _has_delayed_scaling_state(fp8_meta: Dict[str, Any]) -> bool: + """Check if fp8_meta has delayed scaling state (built-in or custom).""" + if fp8_meta["recipe"].delayed(): + return True + if fp8_meta["recipe"].custom(): + for key in ("scaling_fwd", "scaling_bwd"): + state = fp8_meta.get(key) + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + return True + return False + + class CustomRecipeState(RecipeState): - """State for CustomRecipe: produce quantizers per tensor.""" + """State for CustomRecipe: produce quantizers per tensor. + + Stateful quantizer support: + - Supports stateful quantizers (e.g. delayed scaling) via ``DelayedScalingRequest``. 
+ - The factory returns request dataclasses for stateful quantizers; TE detects them, + allocates shared buffers, and replaces with real quantizer instances. + - Stateful recipe state is composed via real TE recipe state objects (e.g. + ``DelayedScalingRecipeState``), not reimplemented. + """ recipe: CustomRecipe mode: str num_quantizers: int device: Optional[torch.device] + _ds_state: Optional[DelayedScalingRecipeState] def __init__( self, @@ -1426,6 +1571,7 @@ def __init__( if device is None: device = torch.device("cuda") self.device = device + self._ds_state = None if getattr(recipe, "qfactory", None) is None: raise ValueError("CustomRecipe requires `qfactory`.") @@ -1448,8 +1594,24 @@ def make_quantizers(self) -> list: f"({len(roles)=} vs {self.num_quantizers=})" ) - out = [] - for i in range(self.num_quantizers): - quantizer = qfactory(roles[i]) - out.append(quantizer) - return out + raw = [qfactory(roles[i]) for i in range(self.num_quantizers)] + self._ds_state = _handle_delayed_scaling_requests(raw, self.device, self.mode) + return raw + + # -- Delegation to composed DelayedScalingRecipeState -- + + @property + def _has_delayed_scaling(self) -> bool: + return self._ds_state is not None + + @property + def amax_history(self) -> Optional[torch.Tensor]: + return self._ds_state.amax_history if self._ds_state else None + + @property + def scale(self) -> Optional[torch.Tensor]: + return self._ds_state.scale if self._ds_state else None + + @property + def _inner_delayed_scaling_recipe(self) -> Optional[DelayedScaling]: + return self._ds_state.recipe if self._ds_state else None From c0c78ea3a6082e610316debd3ee3d71af5f46372 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Mon, 2 Mar 2026 15:15:38 +0000 Subject: [PATCH 29/36] Fix save_original_input for custom delayed scaling Signed-off-by: Evgeny --- .../pytorch/module/grouped_linear.py | 15 +++++++++++++- transformer_engine/pytorch/module/linear.py | 20 +++++++++++++++---- 2 files changed, 30 insertions(+), 5 
deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e5a09241a1..08e6157116 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -107,7 +107,20 @@ def forward( # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): - raise ValueError("DelayedScaling recipe is not supported with save_original_input") + if module.fp8_meta["recipe"].custom(): + # Custom recipe factory may produce DS quantizers unknown to caller. + # TODO(negvet): fix on Megatron side — guard should also exclude 'custom', or + # better: check at runtime whether quantizers are DS-based. + warnings.warn( + "save_original_input is incompatible with delayed-scaling quantizers " + "(Float8Quantizer). Disabling save_original_input for this module.", + stacklevel=2, + ) + save_original_input = False + else: + raise ValueError( + "DelayedScaling recipe is not supported with save_original_input" + ) if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index adf44208e5..9db859a3bc 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -176,10 +176,22 @@ def forward( if fp8: assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) - if save_original_input: - assert not isinstance( - input_quantizer, Float8Quantizer - ), "DelayedScaling recipe is not supported with save_original_input" + if save_original_input and isinstance(input_quantizer, Float8Quantizer): + if module.fp8_meta["recipe"].custom(): + # Custom recipe factory may produce DS quantizers unknown to caller. 
+ # TODO(negvet): fix on Megatron side — guard in attention.py checks + # `fp8_recipe != 'delayed'` but should also exclude 'custom', or + # better: check at runtime whether quantizers are DS-based. + warnings.warn( + "save_original_input is incompatible with delayed-scaling quantizers " + "(Float8Quantizer). Disabling save_original_input for this module.", + stacklevel=2, + ) + save_original_input = False + else: + raise AssertionError( + "DelayedScaling recipe is not supported with save_original_input" + ) if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor From e5250aca0a71aa4433c2d8f0d129adae4f00316a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 15:28:58 +0000 Subject: [PATCH 30/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 8 ++------ transformer_engine/pytorch/module/grouped_linear.py | 4 +--- transformer_engine/pytorch/quantization.py | 11 ++++++----- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 8cbf7fb0ac..4a564145d1 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -393,9 +393,7 @@ def test_factory_matches_delayed_scaling(): model_ref, model_cus, inp_ref, inp_cus = _make_pair() - out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( - model_ref, inp_ref, recipe.DelayedScaling() - ) + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd(model_ref, inp_ref, recipe.DelayedScaling()) out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( model_cus, inp_cus, recipe.CustomRecipe(qfactory=delayed_scaling_quantizer_factory) ) @@ -830,6 +828,4 @@ def ds_factory(role): # After 3 steps, amax_history should have been updated at least once # The first row (amax_history[0]) should differ from the initial zeros # after the first 
step - assert not torch.all(amax_snapshots[0] == 0), ( - "amax_history should be updated after first step" - ) + assert not torch.all(amax_snapshots[0] == 0), "amax_history should be updated after first step" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 08e6157116..5d77e73f60 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -118,9 +118,7 @@ def forward( ) save_original_input = False else: - raise ValueError( - "DelayedScaling recipe is not supported with save_original_input" - ) + raise ValueError("DelayedScaling recipe is not supported with save_original_input") if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 16b920036d..c7157dc54d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -506,9 +506,7 @@ def add_fp8_tensors_to_global_buffer( else: amax_hist = state.amax_history scale_tensor = state.scale - key = cls.get_key_in_buffer( - forward, fp8_meta["recipe"], fp8_meta["fp8_group"] - ) + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [amax_hist[0]] @@ -1502,7 +1500,7 @@ def _handle_delayed_scaling_requests( vi = getattr(req, field_name) if v0 != vi: raise ValueError( - f"All DelayedScalingRequests in one CustomRecipeState must match. " + "All DelayedScalingRequests in one CustomRecipeState must match. " f"Slot 0 has {field_name}={v0!r}, slot {idx} has {vi!r}." 
) @@ -1517,7 +1515,10 @@ def _handle_delayed_scaling_requests( ) n = len(ds_items) dsrs = DelayedScalingRecipeState( - inner_recipe, mode=mode, num_quantizers=n, device=device, + inner_recipe, + mode=mode, + num_quantizers=n, + device=device, ) # Splice Float8Quantizer instances (backed by dsrs buffers) into raw list. From b8c07029d53b4cd3364670ff6911f9b759987733 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Mon, 2 Mar 2026 18:54:28 +0000 Subject: [PATCH 31/36] Enable custom recipe for attn Signed-off-by: Evgeny --- tests/pytorch/test_custom_recipe.py | 129 +++++++++++++++++- transformer_engine/common/recipe/__init__.py | 5 + .../dot_product_attention.py | 25 +++- .../attention/dot_product_attention/utils.py | 18 +++ .../quantization_factory_examples.py | 88 ++++++++++++ 5 files changed, 261 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 4a564145d1..20fc271964 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -828,4 +828,131 @@ def ds_factory(role): # After 3 steps, amax_history should have been updated at least once # The first row (amax_history[0]) should differ from the initial zeros # after the first step - assert not torch.all(amax_snapshots[0] == 0), "amax_history should be updated after first step" + assert not torch.all(amax_snapshots[0] == 0), ( + "amax_history should be updated after first step" + ) + + +def test_custom_recipe_dpa_fp8(): + """DotProductAttention forward+backward with CustomRecipe and role-based mixed quantizers. 
+ + Uses the nvfp4_linear_fp8_dpa_factory which dispatches: + * DPA S/dP slots -> DelayedScalingRequest (stateful) + * DPA QKV/O/dO/dQKV slots -> Float8CurrentScalingQuantizer + * Linear slots -> NVFP4Quantizer + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.utils import get_device_compute_capability + + cc = get_device_compute_capability() + if cc < (9, 0) or cc >= (12, 0): + pytest.skip(f"FP8 attention not supported on sm{cc[0]*10+cc[1]}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, + ) + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + torch.manual_seed(42) + + H = 64 + NH = 4 + KV = H // NH + B = 2 + S = 32 + + # Build a small model: Linear -> DPA -> Linear + qkv_proj = Linear(H, 3 * H, params_dtype=torch.bfloat16, bias=False, name="qkv").cuda() + dpa = te.DotProductAttention( + NH, KV, attention_dropout=0.0, qkv_format="bshd", name="core_attention" + ) + out_proj = Linear(H, H, params_dtype=torch.bfloat16, bias=False, name="proj").cuda() + + custom_recipe = recipe.CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + + inp = torch.randn(B, S, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + qkv = qkv_proj(inp).view(B, S, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn_out = dpa(q, k, v, qkv_format="bshd").reshape(B, S, H) + out = out_proj(attn_out) + + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None, "Input gradient should 
exist" + + # Verify DPA recipe state is CustomRecipeState + fwd_state = dpa.fp8_meta["scaling_fwd"] + assert isinstance(fwd_state, CustomRecipeState), ( + f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" + ) + + # Verify DPA quantizers: 9 forward slots (3 GEMMs x 3) + fwd_quantizers = dpa.quantizers["scaling_fwd"] + assert len(fwd_quantizers) == 9, f"Expected 9 fwd quantizers, got {len(fwd_quantizers)}" + + # Slots 0-2: QKV (GEMM1) -> current scaling (role: module_type="dpa") + # Slots 3-5: O (GEMM2) -> current scaling (role: name hint "dpa_output") + # Slots 6-8: S (GEMM3) -> delayed scaling (Float8Quantizer from DelayedScalingRequest) + for i in range(6): + assert isinstance(fwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Slot {i} (QKV/O): expected Float8CurrentScalingQuantizer, " + f"got {type(fwd_quantizers[i]).__name__}" + ) + for i in range(6, 9): + assert isinstance(fwd_quantizers[i], Float8Quantizer), ( + f"Slot {i} (S): expected Float8Quantizer (delayed scaling), " + f"got {type(fwd_quantizers[i]).__name__}" + ) + + # Verify DS state exists for the S/dP delayed scaling requests + assert fwd_state._has_delayed_scaling, "DPA fwd state should have delayed scaling for S slots" + + # Verify backward quantizers exist too + bwd_quantizers = dpa.quantizers["scaling_bwd"] + assert len(bwd_quantizers) == 6, f"Expected 6 bwd quantizers, got {len(bwd_quantizers)}" + + # Slots 0-1: dQKV (GEMM1) -> current scaling (role: name hint "dpa_grad_input") + # Slots 2-3: dO (GEMM2) -> current scaling (role: module_type="dpa") + # Slots 4-5: dP (GEMM3) -> delayed scaling + for i in range(4): + assert isinstance(bwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Bwd slot {i} (dQKV/dO): expected Float8CurrentScalingQuantizer, " + f"got {type(bwd_quantizers[i]).__name__}" + ) + for i in range(4, 6): + assert isinstance(bwd_quantizers[i], Float8Quantizer), ( + f"Bwd slot {i} (dP): expected Float8Quantizer (delayed scaling), " + f"got 
{type(bwd_quantizers[i]).__name__}" + ) + + # Linear modules should have CustomRecipeState with NVFP4 quantizers + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + qkv_fwd = qkv_proj.fp8_meta["scaling_fwd"] + assert isinstance(qkv_fwd, CustomRecipeState), ( + f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" + ) + qkv_fwd_quantizers = qkv_proj.quantizers["scaling_fwd"] + for i, q in enumerate(qkv_fwd_quantizers): + if q is not None: + assert isinstance(q, NVFP4Quantizer), ( + f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" + ) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 01a741378e..c6cae2efcb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -529,6 +529,11 @@ class CustomRecipe(Recipe): qfactory: Callable[..., Any] + # fp8_format does not affect quantization (quantization factory controls that), + # but TE internals (e.g. get_fp8_te_dtype, backend selection) read it + # from the recipe. HYBRID (E4M3 fwd, E5M2 bwd) is a safe default. 
+ fp8_format: Format = Format.HYBRID + fp8_dpa: bool = False fp8_mha: bool = False diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 8808e84482..d6271810b2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -588,6 +588,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_recipe.custom(): + super().init_fp8_metadata(num_gemms=num_gemms) return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to @@ -760,6 +761,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe) and recipe.custom(): + return TransformerEngineBaseModule.set_meta_tensor(self, fwd, recipe) if isinstance(recipe, Recipe): recipe = [recipe] fp8_recipe_dpa = recipe[-1] @@ -830,14 +833,26 @@ def get_quantizer_roles( Unused positions in each GEMM group share the role of the group's primary tensor. - The O (fwd) and dQKV (bwd) slots mirror the output / grad-input - pattern from linear modules. Set :attr:`output_quantizer_role` / - :attr:`grad_input_quantizer_role` to provide consumer identity. + The O (fwd) and dQKV (bwd) slots are **boundary** tensors whose + consumer is unknown to DPA. Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity + (e.g. from ``MultiheadAttention``). + + When the consumer is not set, a hint-only ``QuantizerRole`` with + ``module_type=""`` and ``tensor_type=""`` is emitted. 
Its ``name`` + field carries ``".dpa_output"`` or ``".dpa_grad_input"`` + so that the quantizer factory can distinguish DPA boundary slots + from other modules' boundary slots. This is needed because the + fused-attention kernel requires FP8-compatible quantizers in all + slots -- the factory must return an FP8 quantizer for these hints + rather than e.g. NVFP4. """ name = self.name or "" if fwd: qkv_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name=name) o_role = self._output_quantizer_role + if o_role is None: + o_role = QuantizerRole(name=f"{name}.dpa_output" if name else "dpa_output") s_role = QuantizerRole(module_type="dpa", tensor_type="s", name=name) base = [ qkv_role, @@ -852,6 +867,10 @@ def get_quantizer_roles( ] else: dqkv_role = self._grad_input_quantizer_role + if dqkv_role is None: + dqkv_role = QuantizerRole( + name=f"{name}.dpa_grad_input" if name else "dpa_grad_input" + ) do_role = QuantizerRole(module_type="dpa", tensor_type="do", name=name) dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) base = [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..3d09ee682d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2118,6 +2118,24 @@ def get_attention_quantizers(fp8, quantizers): dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + _fp8_types = (Float8Quantizer, Float8CurrentScalingQuantizer) + for _name, _q in [ + ("QKV", QKV_quantizer), + ("O", O_quantizer), + ("S", S_quantizer), + ("dQKV", dQKV_quantizer), + ("dO", dO_quantizer), + ("dP", dP_quantizer), + ]: + assert isinstance(_q, _fp8_types), ( + f"FP8 attention requires FP8-compatible quantizers for all DPA tensor slots, " + f"but {_name} quantizer is {type(_q).__name__}. 
" + f"When using CustomRecipe with fp8_dpa=True, ensure the factory returns an " + f"FP8 quantizer (Float8Quantizer or Float8CurrentScalingQuantizer) for all " + f"DPA roles (module_type='dpa') and for None roles (boundary slots like " + f"O output and dQKV grad-input)." + ) + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py index 5c563da8a9..6f5fead4e3 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -14,11 +14,18 @@ from transformer_engine.pytorch.quantization import autocast from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( nvfp4_linear_mxfp8_grouped_linear_factory, + nvfp4_linear_fp8_dpa_factory, ) + # Mixed module types: NVFP4 for Linear, MXFP8 for GroupedLinear recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_grouped_linear_factory) with autocast(recipe=recipe): output = model(input) + + # NVFP4 for Linear, FP8 current-scaling + delayed-scaling for DPA + recipe = CustomRecipe(qfactory=nvfp4_linear_fp8_dpa_factory, fp8_dpa=True) + with autocast(recipe=recipe): + output = model(input) """ from __future__ import annotations @@ -104,3 +111,84 @@ def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): stochastic_rounding=False, with_random_sign_mask=True, ) + + +def nvfp4_linear_fp8_dpa_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, mixed FP8 for ``DotProductAttention``. + + This factory demonstrates how to use ``CustomRecipe`` with ``fp8_dpa=True`` + to combine NVFP4 quantization for linear layers with FP8 attention. 
+ + DPA tensor types (``role.module_type == "dpa"``): + + =========== ============================================================ + tensor_type Description + =========== ============================================================ + ``"qkv"`` Query, Key, Value inputs to the first attention GEMM + ``"s"`` Softmax output (S = softmax(Q·K^T)), fed into the second GEMM + ``"o"`` Attention output (O = S·V) + ``"do"`` Gradient of the attention output (dO), backward input + ``"dp"`` Gradient of the softmax output (dP = dO·V^T), backward + ``"dqkv"`` Gradient flowing back to Q, K, V + =========== ============================================================ + + Dispatch logic: + * ``role.module_type == "dpa"`` with ``tensor_type in ("s", "dp")`` + -> FP8 delayed scaling (stateful amax tracking) + * ``role.module_type == "dpa"`` (QKV, dO) + -> FP8 current scaling (E4M3) + * DPA boundary hints (``"dpa_output"`` / ``"dpa_grad_input"`` in ``role.name``) + -> FP8 current scaling placeholder. The fused attention kernel requires + FP8-compatible quantizers in all DPA slots, even when the output is + produced in BF16 (``fp8_mha=False``). DPA emits these hint-only roles + (with empty ``module_type`` and ``tensor_type``) when the downstream + consumer is unknown. 
+ * everything else (``"linear"`` / ``"grouped_linear"`` / ``None``) + -> NVFP4 (E2M1), configured per tensor role + + Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + recipe = CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + with autocast(recipe=recipe): + output = model(input) + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer + + is_dpa = role is not None and role.module_type == "dpa" + is_softmax_or_dp = is_dpa and role.tensor_type in ("s", "dp") + + if is_softmax_or_dp: + return DelayedScalingRequest() + + if is_dpa: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + # DPA boundary slots (O output / dQKV grad-input): the fused attention + # kernel only supports FP8 quantizers here, regardless of the linear recipe. 
+ is_dpa_boundary = ( + role is not None + and not role.module_type + and ("dpa_output" in role.name or "dpa_grad_input" in role.name) + ) + if is_dpa_boundary: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + return _make_nvfp4_quantizer(role) From 488d5e6ab83c72e9e01f1a7ccb723cdb6f6aa0ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:01:21 +0000 Subject: [PATCH 32/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 22 +++++++++---------- .../attention/dot_product_attention/utils.py | 10 ++++----- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 20fc271964..bfd31dbde3 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -828,9 +828,7 @@ def ds_factory(role): # After 3 steps, amax_history should have been updated at least once # The first row (amax_history[0]) should differ from the initial zeros # after the first step - assert not torch.all(amax_snapshots[0] == 0), ( - "amax_history should be updated after first step" - ) + assert not torch.all(amax_snapshots[0] == 0), "amax_history should be updated after first step" def test_custom_recipe_dpa_fp8(): @@ -900,9 +898,9 @@ def test_custom_recipe_dpa_fp8(): # Verify DPA recipe state is CustomRecipeState fwd_state = dpa.fp8_meta["scaling_fwd"] - assert isinstance(fwd_state, CustomRecipeState), ( - f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" - ) + assert isinstance( + fwd_state, CustomRecipeState + ), f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" # Verify DPA quantizers: 9 forward slots (3 GEMMs x 3) fwd_quantizers = dpa.quantizers["scaling_fwd"] @@ -947,12 +945,12 @@ def test_custom_recipe_dpa_fp8(): 
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer qkv_fwd = qkv_proj.fp8_meta["scaling_fwd"] - assert isinstance(qkv_fwd, CustomRecipeState), ( - f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" - ) + assert isinstance( + qkv_fwd, CustomRecipeState + ), f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" qkv_fwd_quantizers = qkv_proj.quantizers["scaling_fwd"] for i, q in enumerate(qkv_fwd_quantizers): if q is not None: - assert isinstance(q, NVFP4Quantizer), ( - f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" - ) + assert isinstance( + q, NVFP4Quantizer + ), f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 3d09ee682d..8ad2376e3e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2128,12 +2128,12 @@ def get_attention_quantizers(fp8, quantizers): ("dP", dP_quantizer), ]: assert isinstance(_q, _fp8_types), ( - f"FP8 attention requires FP8-compatible quantizers for all DPA tensor slots, " + "FP8 attention requires FP8-compatible quantizers for all DPA tensor slots, " f"but {_name} quantizer is {type(_q).__name__}. " - f"When using CustomRecipe with fp8_dpa=True, ensure the factory returns an " - f"FP8 quantizer (Float8Quantizer or Float8CurrentScalingQuantizer) for all " - f"DPA roles (module_type='dpa') and for None roles (boundary slots like " - f"O output and dQKV grad-input)." + "When using CustomRecipe with fp8_dpa=True, ensure the factory returns an " + "FP8 quantizer (Float8Quantizer or Float8CurrentScalingQuantizer) for all " + "DPA roles (module_type='dpa') and for None roles (boundary slots like " + "O output and dQKV grad-input)." 
) return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer From 89d14481392524e3114d8b3ff0ebaf6836f58fe3 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 4 Mar 2026 16:03:52 +0000 Subject: [PATCH 33/36] Make boundary role setting more explicit in MHA Signed-off-by: Evgeny --- tests/pytorch/test_custom_recipe.py | 4 +- .../pytorch/attention/multi_head_attention.py | 48 +++++++++++++------ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index ecea70d895..38e048ac5c 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -650,7 +650,9 @@ def _tl_names(prefix): for r in recorded_roles: if r is not None and r.module_type: - assert r.module_type == "linear", f"Unexpected module_type={r.module_type} for role {r}" + assert r.module_type in ("linear", "dpa"), ( + f"Unexpected module_type={r.module_type} for role {r}" + ) # -- Quantizer-type checks -- from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 8d8a491db3..1607d513b1 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -485,29 +485,38 @@ def _update_output_quantizer_roles( proj_fp8_grad: bool, dpa_fp8_output: bool, ) -> None: - """Set output / grad-input quantizer roles on QKV, proj, and DPA. + """Set quantizer roles at the boundaries between QKV, DPA, and proj. - When the QKV linear's output feeds directly into DPA (``fp8_mha``), - the role is switched from the default linear-consumer assumption to - DPA-consumer roles. Otherwise roles are reset to ``None`` so the - modules fall back to their defaults. 
+ MHA contains three submodules connected as follows:: - Symmetrically, when DPA produces FP8 output / gradients, its - ``output_quantizer_role`` (O -> proj) and - ``grad_input_quantizer_role`` (dQKV -> QKV linear) are set to - describe the consuming linear module. + Forward: QKV linear ──(QKV tensor)──> DPA ──(O tensor)──> Proj linear + Backward: QKV linear <──(dQKV tensor)── DPA <──(dO tensor)── Proj linear + + Each submodule owns quantizers for its internal tensors, but the + *boundary* tensors (the arrows above) need to know which module + will *consume* them so the quantizer factory can pick the right + format. This method sets those boundary roles on all four edges: + + 1. ``qkv_fp8_output`` — **QKV linear → DPA (fwd)**: the QKV + linear's ``output_quantizer_role`` is told its consumer is DPA. + 2. ``proj_fp8_grad`` — **Proj linear ← DPA (bwd)**: proj's + ``grad_input_quantizer_role`` is told its producer is DPA. + 3. ``dpa_fp8_output`` — **DPA → Proj linear (fwd)**: DPA's + ``output_quantizer_role`` is told its consumer is the proj linear. + 4. ``dpa_fp8_output`` — **DPA ← QKV linear (bwd)**: DPA's + ``grad_input_quantizer_role`` is told its consumer is QKV linear. + + When a flag is ``False`` the corresponding role is reset to ``None`` + so the module falls back to its own default. 
""" dpa_name = self.core_attention.name or "" + + # ── Boundary 1 (fwd): QKV linear output → consumed by DPA ──────── qkv_output_role = ( QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) if qkv_fp8_output else None ) - proj_grad_input_role = ( - QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) - if proj_fp8_grad - else None - ) if self.attention_type == "self": if self.input_layernorm: self.layernorm_qkv.output_quantizer_role = qkv_output_role @@ -519,15 +528,24 @@ def _update_output_quantizer_roles( else: self.query_layer.output_quantizer_role = qkv_output_role self.key_value.output_quantizer_role = qkv_output_role + + # ── Boundary 2 (bwd): Proj grad-input ← produced by DPA ────────── + proj_grad_input_role = ( + QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) + if proj_fp8_grad + else None + ) self.proj.grad_input_quantizer_role = proj_grad_input_role - # DPA boundary roles: O -> proj (fwd), dQKV -> QKV linear (bwd) + # ── Boundary 3 (fwd): DPA output (O) → consumed by Proj linear ─── proj_name = self.proj.name or "" self.core_attention.output_quantizer_role = ( QuantizerRole(module_type="linear", tensor_type="input", name=proj_name) if dpa_fp8_output else None ) + + # ── Boundary 4 (bwd): DPA grad-input (dQKV) → consumed by QKV linear if self.attention_type == "self": qkv_linear = self.layernorm_qkv if self.input_layernorm else self.qkv else: From f18dc759fa9c2c6195dc4cef0774956c8731a6ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:23:15 +0000 Subject: [PATCH 34/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_custom_recipe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 38e048ac5c..8231d013d3 100644 --- 
a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -650,9 +650,10 @@ def _tl_names(prefix): for r in recorded_roles: if r is not None and r.module_type: - assert r.module_type in ("linear", "dpa"), ( - f"Unexpected module_type={r.module_type} for role {r}" - ) + assert r.module_type in ( + "linear", + "dpa", + ), f"Unexpected module_type={r.module_type} for role {r}" # -- Quantizer-type checks -- from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer From f21ce2f077a24f1612fca858dc2e50fbd4411ac6 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 4 Mar 2026 16:34:05 +0000 Subject: [PATCH 35/36] Make dpa role setting more intuitive Signed-off-by: Evgeny --- .../dot_product_attention.py | 86 +++++++++++-------- 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d6271810b2..091621fa76 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -817,35 +817,51 @@ def get_quantizer_roles( ) -> Optional[List[QuantizerRole]]: """QuantizerRole list for quantizers used by ``DotProductAttention``. - Quantizer positions follow the GEMM-slot convention used by the - fused-attention kernels: - - Forward (3 GEMMs x 3 = 9 slots): - GEMM1 -> QKV (at ``GEMM1_OUTPUT``), - GEMM2 -> O (at ``GEMM2_INPUT``), - GEMM3 -> S (at ``GEMM3_OUTPUT``). - - Backward (3 GEMMs x 2 = 6 slots): - GEMM1 -> dQKV (at ``GRAD_OUTPUT1``), - GEMM2 -> dO (at ``GRAD_INPUT2``), - GEMM3 -> dP (at ``GRAD_INPUT3``). - - Unused positions in each GEMM group share the role of the - group's primary tensor. - - The O (fwd) and dQKV (bwd) slots are **boundary** tensors whose - consumer is unknown to DPA. 
Set :attr:`output_quantizer_role` / - :attr:`grad_input_quantizer_role` to provide consumer identity - (e.g. from ``MultiheadAttention``). - - When the consumer is not set, a hint-only ``QuantizerRole`` with - ``module_type=""`` and ``tensor_type=""`` is emitted. Its ``name`` - field carries ``".dpa_output"`` or ``".dpa_grad_input"`` - so that the quantizer factory can distinguish DPA boundary slots - from other modules' boundary slots. This is needed because the - fused-attention kernel requires FP8-compatible quantizers in all - slots -- the factory must return an FP8 quantizer for these hints - rather than e.g. NVFP4. + DPA internally performs two matmuls:: + + S = softmax(Q · K^T) (GEMM1) + O = S · V (GEMM2) + + cuDNN's fused-attention API exposes FP8 scale/amax descriptors as + a flat array of **slot groups** numbered 1-3. The numbering is a + cuDNN convention — it does *not* correspond to operation order + inside DPA: + + Forward (3 slot groups × 3 positions = 9 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 QKV — inputs to GEMM1 (Q·K^T) GEMM1_OUTPUT + Group 2 O — output of GEMM2 (S·V) GEMM2_INPUT + Group 3 S — post-softmax, input to GEMM2 (S·V) GEMM3_OUTPUT + =========== =========================================== =========== + + Backward (3 slot groups × 2 positions = 6 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 dQKV — gradients flowing back to Q, K, V GRAD_OUTPUT1 + Group 2 dO — gradient of the attention output GRAD_INPUT2 + Group 3 dP — gradient of the softmax output GRAD_INPUT3 + =========== =========================================== =========== + + Unused positions within a group share the role of the group's + primary tensor. 
+ + **Boundary slots** — O (fwd) and dQKV (bwd) leave DPA and enter + the next module (e.g. proj linear). DPA does not know that + consumer, so these default to ``None``. The parent module + (e.g. ``MultiheadAttention``) can set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to fill in the consumer identity. + + When not set, a hint-only ``QuantizerRole`` with empty + ``module_type`` / ``tensor_type`` is emitted, with ``name`` + containing ``"dpa_output"`` or ``"dpa_grad_input"``. This lets + the factory return a DPA-compatible quantizer (required by the + fused kernel) even when the downstream consumer is unknown. """ name = self.name or "" if fwd: @@ -857,13 +873,13 @@ def get_quantizer_roles( base = [ qkv_role, qkv_role, - qkv_role, # GEMM1: QKV at GEMM1_OUTPUT + qkv_role, # Group 1: QKV (inputs to Q·K^T) o_role, o_role, - o_role, # GEMM2: O at GEMM2_INPUT + o_role, # Group 2: O (output of S·V) — boundary s_role, s_role, - s_role, # GEMM3: S at GEMM3_OUTPUT + s_role, # Group 3: S (post-softmax, input to S·V) ] else: dqkv_role = self._grad_input_quantizer_role @@ -875,11 +891,11 @@ def get_quantizer_roles( dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) base = [ dqkv_role, - dqkv_role, # GEMM1: dQKV at GRAD_OUTPUT1 + dqkv_role, # Group 1: dQKV (grads to Q,K,V) — boundary do_role, - do_role, # GEMM2: dO at GRAD_INPUT2 + do_role, # Group 2: dO (grad of attention output) dp_role, - dp_role, # GEMM3: dP at GRAD_INPUT3 + dp_role, # Group 3: dP (grad of softmax output) ] return base[:num_quantizers] From 0c1ec9bc693c161e4f0b7ecd41cb1255c5a3481c Mon Sep 17 00:00:00 2001 From: Evgeny Date: Wed, 4 Mar 2026 16:54:13 +0000 Subject: [PATCH 36/36] Docstring for get_quantizer_roles() in base module Signed-off-by: Evgeny --- transformer_engine/pytorch/module/base.py | 66 ++++++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py 
b/transformer_engine/pytorch/module/base.py index 8d6d744b4c..5b00e58501 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -828,8 +828,70 @@ def get_quantizer_roles( ) -> Optional[List[QuantizerRole]]: """Return an ordered list of :class:`QuantizerRole` for quantizers. - The returned list must have length `num_quantizers`. - Returning `None` means "no explicit roles". + Overview + -------- + When using ``CustomRecipe``, the quantizer factory is called once + per quantizer slot. Each call receives a ``QuantizerRole`` that + tells the factory *what* is being quantized so it can return the + right quantizer. + + This method builds the role list. Subclasses override it to + describe their internal GEMM layout. + + Slot layout + ----------- + Return one ``QuantizerRole`` per slot, in the same order as the + module's quantizer array. For example, ``Linear`` uses 3 + forward slots ``[input, weight, output]`` and 2 backward slots + ``[grad_output, grad_input]``. Multi-GEMM modules like + ``LayerNormMLP`` repeat that pattern for each GEMM: + ``[fc1_input, fc1_weight, fc1_output, fc2_input, fc2_weight, fc2_output]``. + + What to put in each slot + ------------------------ + Create a ``QuantizerRole(module_type=..., tensor_type=..., + name=...)`` for each slot: + + * **module_type** — the kind of module: ``"linear"``, + ``"grouped_linear"``, ``"dpa"``, etc. The factory can dispatch + on this to use different quantization formats per module type. + * **tensor_type** — what tensor this slot holds, in the module's + own vocabulary. For linears: ``"input"``, ``"weight"``, + ``"grad_output"``, etc. For DPA: ``"qkv"``, ``"s"``, + ``"do"``, ``"dp"``, etc. + * **name** — the instance name (e.g. ``"encoder.layer0.fc1"``), + propagated from the ``name=`` constructor argument. The factory + can dispatch on this to target specific layers. 
+ + Boundary slots + -------------- + The last slot of a forward GEMM group (output) and the last slot + of a backward group (grad_input) are **boundary** slots — the + tensor leaves this module and enters an unknown consumer. For + these slots, use ``self._output_quantizer_role`` (fwd) and + ``self._grad_input_quantizer_role`` (bwd), which default to + ``None``. A parent module (e.g. ``MultiheadAttention``) can set + these attributes to fill in the consumer identity; see + ``MultiheadAttention._update_output_quantizer_roles`` for an + example. + + Return value + ------------ + * A list of ``QuantizerRole`` with length ``num_quantizers``. + * ``None`` to opt out of role-based dispatch. + + Not implemented (default) + ~~~~~~~~~~~~~~~~~~~~~~~~~ + The base implementation returns ``None``. When ``None`` is + returned, ``CustomRecipeState`` emits a warning and falls back + to bare ``QuantizerRole()`` instances (all fields empty) for + every slot. The factory still gets called once per slot, but + every call receives an identical empty role — it cannot + distinguish input from weight, forward from backward, or one + module from another. What happens then depends entirely on the + factory: it may return the same quantizer for all slots (acting + as a uniform recipe), or it may raise an error if it requires + meaningful roles to dispatch on. """ return None