[PyTorch] Introduce quantizer roles #2620

Open
negvet wants to merge 39 commits into NVIDIA:main from negvet:semantic_quantizer_roles

Conversation

@negvet
Collaborator

@negvet negvet commented Jan 23, 2026

Description

Introducing QuantizerRole

@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""   # e.g. "linear", "grouped_linear", "dpa"
    tensor_type: str = ""   # e.g. "input", "weight", "grad_output", "qkv", "s"
    name: str = ""          # instance name, e.g. "qkv", "proj", "fc1", "fc2"

This API allows users to express intents such as "set this LayerNormLinear in this transformer layer to be less aggressively quantized" — in other words, a fine-grained, per-module/per-tensor quantization control mechanism.
See test_custom_recipe.py::test_custom_recipe_quantization_targets().
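A minimal sketch of such a factory, with placeholder strings standing in for real quantizer objects (the role fields mirror the dataclass above; the dispatch rules and placeholder names are hypothetical, not the TE API):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

def qfactory(role):
    # Output/grad-input slots may carry no role; fall back to defaults.
    if role is None:
        return None
    # Quantize the "proj" weight less aggressively than everything else.
    if role.name == "proj" and role.tensor_type == "weight":
        return "mxfp8_quantizer"   # placeholder for an MXFP8 quantizer
    return "nvfp4_quantizer"       # placeholder for an NVFP4 quantizer

print(qfactory(QuantizerRole("linear", "weight", "proj")))  # mxfp8_quantizer
print(qfactory(QuantizerRole("linear", "input", "fc1")))    # nvfp4_quantizer
```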

Quantizer factory uses roles to dispatch according to its needs.

TE module/op emits a list of QuantizerRole:

  • Linear, LayerNormLinear, LayerNormMLP emit module_type="linear" with tensor_type in {"input", "weight", "grad_output"}.
  • GroupedLinear emits module_type="grouped_linear".

CustomRecipe accepts a qfactory callable that receives QuantizerRole and returns a quantizer.

Factories can be composed: for example, dispatch to different sub-factories based on module_type (dpa vs. linear), then refine based on tensor_type.
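One way to sketch such composition, with placeholder sub-factories and quantizer names (illustrative only, not the TE API):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

def linear_factory(role):
    # Second-level refinement on tensor_type within the linear sub-factory.
    if role.tensor_type == "grad_output":
        return "fp8_e5m2_quantizer"   # wider dynamic range for gradients
    return "fp8_e4m3_quantizer"

def dpa_factory(role):
    return "fp8_attention_quantizer"

def composed_factory(role):
    # First-level dispatch on module_type (dpa vs. linear).
    if role is None:
        return None
    if role.module_type == "dpa":
        return dpa_factory(role)
    return linear_factory(role)

print(composed_factory(QuantizerRole("dpa", "s", "attn")))
print(composed_factory(QuantizerRole("linear", "grad_output", "fc1")))
```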

Summary:

  • Modules implement get_quantizer_roles() that returns a list of QuantizerRole objects.
  • During set_meta_tensor(), modules call get_quantizer_roles() and pass roles to RecipeState.create().
  • RecipeState.create() assigns roles to the state (e.g., CustomRecipeState.roles).
  • CustomRecipeState.make_quantizers() calls qfactory(role) for each role to create quantizers.
  • The factory can inspect role.module_type, role.tensor_type, and role.name to dispatch to different quantizers.
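The flow above can be condensed into a toy stand-in (class and method names mirror the PR description, but the body is a simplified sketch, not the actual implementation):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

class CustomRecipeState:
    """Toy stand-in showing the role-driven flow described above."""
    def __init__(self, qfactory, roles):
        self.qfactory = qfactory
        self.roles = roles          # assigned during RecipeState.create()

    def make_quantizers(self):
        # One factory call per slot; the factory dispatches on the role.
        return [self.qfactory(r) for r in self.roles]

roles = [
    QuantizerRole("linear", "input", "fc1"),
    QuantizerRole("linear", "weight", "fc1"),
    None,  # unconfigured output slot
]
state = CustomRecipeState(
    lambda r: None if r is None else f"{r.tensor_type}_q", roles
)
print(state.make_quantizers())  # ['input_q', 'weight_q', None]
```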

Type of change

  • Documentation change (changes only the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

negvet and others added 4 commits January 23, 2026 15:14
…ipe state

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested review from cyanguwa and timmoon10 January 23, 2026 15:32
@greptile-apps
Contributor

greptile-apps bot commented Jan 23, 2026

Greptile Summary

This PR introduces QuantizerRole — a frozen dataclass with module_type, tensor_type, and name fields — as the new API for the CustomRecipe.qfactory callable. Previously the factory received hardcoded role strings like "linear_input"; now it receives structured QuantizerRole objects, enabling fine-grained, per-module/per-tensor quantization control (e.g. "use MXFP8 for this specific LayerNormLinear in transformer layer 0, NVFP4 everywhere else"). The PR also adds DelayedScalingRequest support inside CustomRecipe, allowing users to request stateful quantizers by returning a request dataclass instead of a quantizer instance. All TE modules (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear, DotProductAttention) now implement get_quantizer_roles(), and MultiheadAttention wires boundary roles across the QKV→DPA→Proj chain.

Key changes:

  • QuantizerRole, QuantizerRequest, DelayedScalingRequest added to public API in transformer_engine/pytorch/quantization.py
  • CustomRecipeState.make_quantizers() refactored from hardcoded string roles to role-based dispatch
  • _has_delayed_scaling_state() replaces recipe.delayed() checks throughout base.py and FP8GlobalStateManager to support custom recipes with mixed DS/stateless quantizers
  • Files renamed: quantization_nvfp4.py → quantization_ref_nvfp4.py, quantization_current_scaling.py → quantization_ref_current_scaling.py
  • New files: quantization_recipes_base.py (silicon quantizer factories mirroring built-in recipes), quantization_factory_examples.py (composite factory examples)
  • Two issues found in LayerNormMLP.get_quantizer_roles(): forward slots 2 and 3 carry identical roles (fc2 input), making GEMM1_OUTPUT indistinguishable from GEMM2_INPUT; backward slots 0 and 3 also carry identical roles (fc1 grad_output). Fine-grained per-slot factory dispatch cannot distinguish these pairs.
  • Misleading inline comments in quantization_recipes_base.py invert the meaning of force_pow_2_scales=False and amax_epsilon=0.0.

Confidence Score: 3/5

  • PR introduces a well-designed new API but contains role assignment bugs in LayerNormMLP that could silently misdirect custom factories for users targeting fine-grained per-slot quantization control.
  • The overall architecture is clean and the API design is solid. Linear, LayerNormLinear, GroupedLinear, BasicLinear, and DotProductAttention all implement get_quantizer_roles() correctly. However, LayerNormMLP.get_quantizer_roles() has a logic issue where forward slots 2 and 3 are identical (both labeled as fc2 input rather than fc1 output + fc2 input respectively), and backward slots 0 and 3 are identical. Since this is experimental API and the test coverage focuses on name-level dispatch rather than GEMM-slot-level dispatch, the issue would not be caught by the current tests. The misleading comments in quantization_recipes_base.py are a minor but clear error. The score is lowered from 4 to 3 due to the LayerNormMLP slot ambiguity.
  • transformer_engine/pytorch/module/layernorm_mlp.py (duplicate slot roles in get_quantizer_roles), transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py (misleading inline comments)

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantization.py Introduces QuantizerRole, QuantizerRequest, and DelayedScalingRequest dataclasses; refactors CustomRecipeState.make_quantizers() to use role-based dispatch instead of hardcoded string roles; adds _handle_delayed_scaling_requests() and _has_delayed_scaling_state() helpers for mixed stateful/stateless quantizers. Core logic looks correct.
transformer_engine/pytorch/module/layernorm_mlp.py Adds get_quantizer_roles() for LayerNormMLP. Forward slots 2 and 3 carry identical roles (both labeled as fc2 input), making GEMM1_OUTPUT indistinguishable from GEMM2_INPUT to a custom factory. Backward slots 0 and 3 are also identical. These duplicates limit fine-grained per-slot dispatch.
transformer_engine/pytorch/module/base.py Adds output_quantizer_role/grad_input_quantizer_role properties with invalidation logic, get_quantizer_roles() base implementation (returns None), and _warn_missing_output_quantizer_role(). Replaces recipe.delayed() checks with _has_delayed_scaling_state() for CustomRecipe DS support.
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py Adds get_quantizer_roles() for DotProductAttention with detailed slot-group documentation. Adds name parameter and custom recipe path to init_fp8_metadata/set_meta_tensor. Boundary slots (O output, dQKV grad-input) default to hint-only roles.
transformer_engine/pytorch/attention/multi_head_attention.py Adds _update_output_quantizer_roles() which wires boundary roles across QKV→DPA→Proj. Well-documented with ASCII diagram. Called in the forward path before layernorm_output computation.
transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py New file with factories mirroring built-in recipes. Contains two misleading inline comments: force_pow_2_scales=False annotated as "constrain scale to powers of 2" and amax_epsilon=0.0 annotated as "clamp amax from below to avoid div-by-zero" — both invert the parameter semantics.

Sequence Diagram

sequenceDiagram
    participant M as TE Module<br/>(Linear/LayerNormMLP/DPA)
    participant B as TransformerEngineBaseModule
    participant RS as RecipeState.create()
    participant CRS as CustomRecipeState
    participant F as qfactory(role)

    M->>B: forward() → init_fp8_metadata()
    B->>M: get_quantizer_roles(fwd, num_quantizers)
    M-->>B: List[QuantizerRole | None]
    B->>RS: create(recipe, mode, num_quantizers, roles=roles)
    RS->>CRS: __init__() + state.roles = roles
    B->>CRS: make_quantizers()
    loop for each slot i
        CRS->>F: qfactory(roles[i])
        F-->>CRS: Quantizer | DelayedScalingRequest | None
    end
    CRS->>CRS: _handle_delayed_scaling_requests()<br/>(allocate shared DS buffers if needed)
    CRS-->>B: List[Quantizer]
    B-->>M: quantizers ready for GEMM

Last reviewed commit: 0c1ec9b


Signed-off-by: Evgeny <etsykunov@nvidia.com>

Signed-off-by: Evgeny <etsykunov@nvidia.com>

Signed-off-by: Evgeny <etsykunov@nvidia.com>

Collaborator

@timmoon10 timmoon10 left a comment

Overall this design is quite clean and generalizable.

Comment on lines +1320 to +1329
base = [
QuantizerRole(module_type="linear", tensor_type="input", name=name),
QuantizerRole(module_type="linear", tensor_type="weight", name=name),
QuantizerRole(module_type="linear", tensor_type="output", name=name),
]
else:
base = [
QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
]
Collaborator

@timmoon10 timmoon10 Feb 20, 2026

"output" and "grad_input" roles don't make sense. In reality, we are implicitly assuming that the tensor will be consumed by another linear-like layer.

Suggested change
base = [
QuantizerRole(module_type="linear", tensor_type="input", name=name),
QuantizerRole(module_type="linear", tensor_type="weight", name=name),
QuantizerRole(module_type="linear", tensor_type="output", name=name),
]
else:
base = [
QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
]
base = [
QuantizerRole(module_type="linear", tensor_type="input", name=name),
QuantizerRole(module_type="linear", tensor_type="weight", name=name),
QuantizerRole(module_type="linear", tensor_type="input", name=name),
]
else:
base = [
QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
]

Alternatively, if we want to use the output in FP8 DPA, the right role would be module_type="dpa" and tensor_type="input". We should probably make this configurable. I kind of like that this design exposes the hidden assumptions we've been making.

Collaborator Author

@negvet negvet Feb 25, 2026

I agree about the "output" and "grad_input" roles. Setting the roles for those slots to None (the safest option) and enabling the configuration. Also configured it in MHA.

Comment on lines +310 to +314
assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1
Collaborator

Suggested change
assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1
assert counts["input"] == 2
assert counts["weight"] == 1
assert counts["output"] == 0
assert counts["grad_output"] == 2
assert counts["grad_input"] == 0

negvet and others added 2 commits February 20, 2026 14:31
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
greptile-apps[bot]

This comment was marked as resolved.

negvet and others added 5 commits February 20, 2026 15:05
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

15 files reviewed, no comments


Comment on lines +85 to +88
def is_gemm(self) -> bool:
"""Whether this role belongs to a GEMM-based module."""
return self.module_type in self.GEMM_MODULE_TYPES

Collaborator

I think this is baking in assumptions about what formats are similar (our recent experiences with grouped tensors makes me wonder if the requirements for "linear" and "grouped_linear" will diverge in the future), and it's also not giving us that much convenience.

Suggested change
def is_gemm(self) -> bool:
"""Whether this role belongs to a GEMM-based module."""
return self.module_type in self.GEMM_MODULE_TYPES

Collaborator Author

Sure, removed

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

24 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

24 files reviewed, no comments


Evgeny and others added 2 commits February 25, 2026 16:43
Signed-off-by: Evgeny <etsykunov@gmail.com>
@negvet negvet requested review from ptrendx and timmoon10 February 25, 2026 16:45
Contributor

@greptile-apps greptile-apps bot left a comment

24 files reviewed, no comments


@greptile-apps
Contributor

greptile-apps bot commented Mar 2, 2026

Additional Comments (2)

tests/pytorch/distributed/run_numerics_exact.py, line 63
Missing None check before accessing role.tensor_type. According to module implementations (Linear, LayerNormLinear, etc.), role can be None for output and grad_input quantizer slots when output_quantizer_role/grad_input_quantizer_role properties are not set. This will cause AttributeError.

        if role is None:
            return None
        if role.tensor_type == "input":

tests/pytorch/nvfp4/test_nvfp4_module_exact.py, line 83
Missing None check before accessing role.tensor_type. role can be None for output and grad_input quantizer slots. Add check at start of factory:

        if role is None:
            return None
        if role.tensor_type == "input":

negvet and others added 3 commits March 2, 2026 15:27
@negvet
Collaborator Author

negvet commented Mar 2, 2026

About 1d63084

Custom recipe factories can now return stateful quantizer requests (not just stateless quantizers) — TE detects these request dataclasses, allocates the required shared state (scale/amax buffers), and replaces them with real quantizer instances wired into existing infrastructure. Factories cannot create stateful quantizers directly because the shared buffers must be allocated across all slots simultaneously and registered with TE's global state manager for cross-module distributed reduction, recompute, and checkpointing — a lifecycle that only TE can orchestrate. Delayed scaling is supported via DelayedScalingRequest; the composed DelayedScalingRecipeState integrates with distributed amax reduction, activation recompute, and checkpointing. Factories can mix stateful requests and stateless quantizers per-slot within the same CustomRecipe.
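The request/resolve split can be sketched as follows (field names, helper names, and placeholder strings here are hypothetical; only the shape of the flow follows the description above):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class DelayedScalingRequest:
    # Hypothetical field: the real request would carry recipe parameters.
    amax_history_len: int = 1024

def qfactory(role_tensor_type):
    # The factory returns a request instead of building a stateful
    # quantizer itself: only TE can allocate the shared scale/amax buffers.
    if role_tensor_type == "weight":
        return DelayedScalingRequest()
    return "stateless_quantizer"

def handle_requests(entries):
    # Toy stand-in for TE's request handling: every request is replaced
    # by a "real" quantizer wired to a shared buffer slot.
    return [
        f"delayed_scaling_quantizer[slot={i}]"
        if isinstance(e, DelayedScalingRequest) else e
        for i, e in enumerate(entries)
    ]

quantizers = handle_requests([qfactory("input"), qfactory("weight")])
print(quantizers)  # ['stateless_quantizer', 'delayed_scaling_quantizer[slot=1]']
```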

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet
Collaborator Author

negvet commented Mar 2, 2026

@cyanguwa @ptrendx, please review. This enables the custom recipe for attention with current scaling; the existing envvar-based routing is still functional.

Comment on lines +823 to +831
Forward (3 GEMMs x 3 = 9 slots):
GEMM1 -> QKV (at ``GEMM1_OUTPUT``),
GEMM2 -> O (at ``GEMM2_INPUT``),
GEMM3 -> S (at ``GEMM3_OUTPUT``).

Backward (3 GEMMs x 2 = 6 slots):
GEMM1 -> dQKV (at ``GRAD_OUTPUT1``),
GEMM2 -> dO (at ``GRAD_INPUT2``),
GEMM3 -> dP (at ``GRAD_INPUT3``).
Member

This is not intuitive to me - O is created by the last gemm, whereas S is created by the first gemm inside DPA (so should be gemm 2?).

Collaborator Author

This is not intuitive to me either; it is a legacy of the cuDNN naming convention. Trying to fix it in f21ce2f

self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

def get_quantizer_roles(
Member

I don't really understand how to properly implement this function. We should have some documentation here on what is expected: say you have some number of GEMMs, you want to assign potentially different quantizers to their inputs/weights, and you would like to name those roles X/Y, etc. — that is what the documentation would need to cover.

Collaborator Author

Improved in 0c1ec9b

Comment on lines +121 to +122
This factory demonstrates how to use ``CustomRecipe`` with ``fp8_dpa=True``
to combine NVFP4 quantization for linear layers with FP8 attention.
Member

To be honest I was hoping that something like this would enable getting rid of fp8_dpa/fp8_mha toggles or at least make them irrelevant for the custom recipe.

Collaborator Author

fp8_dpa/fp8_mha is a kernel-dispatch gate. The factory says how to quantize; fp8_dpa/fp8_mha says whether to run the FP8 attention kernel path at all. It would be possible to infer the gate from the quantizer type produced by the factory — a refactor, but nothing dramatic. Still, I would propose planning it as a follow-up, not in this PR.

cc @cyanguwa

Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py, line 1550
Contradictory inline comment

The comment # constrain scale to powers of 2 describes the exact opposite of the parameter value — False means scales are NOT constrained to powers of 2. A reader will likely infer the wrong behaviour from this comment. It should say something like # do not constrain scale to powers of 2 or mirror the adjacent parameter.

        force_pow_2_scales=False,  # do not constrain scale to powers of 2

transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py, line 1471
nvfp4_linear_fp8_dpa_factory silently produces NVFP4 quantizers for DPA boundary slots when fp8_mha=True

When MultiheadAttention is configured with fp8_mha=True, _update_output_quantizer_roles sets core_attention.output_quantizer_role to a proper QuantizerRole(module_type="linear", tensor_type="input", name=proj_name). With that role, the is_dpa_boundary guard here is never triggered (because role.module_type is "linear", not empty), so the call falls through to _make_nvfp4_quantizer(role). This returns an NVFP4Quantizer, which then fails the assert isinstance(_q, _fp8_types) assertion inside get_attention_quantizers at runtime.

The factory is designed for the standalone fp8_dpa=True / fp8_mha=False scenario where DPA emits hint-only roles (empty module_type), but this is not clearly stated in the docstring. Consider adding an explicit fp8_mha=True guard or at least a note in the docstring warning users that this factory is incompatible with fp8_mha=True:

# DPA boundary slots (O output / dQKV grad-input): the fused attention
# kernel only supports FP8 quantizers here, regardless of the linear recipe.
# NOTE: when fp8_mha=True, MultiheadAttention wires output_quantizer_role with
# module_type="linear", which bypasses this guard.  This factory is only
# designed for standalone fp8_dpa=True usage (fp8_mha=False).
is_dpa_boundary = (
    role is not None
    and (
        # standalone DPA: hint-only role emitted by DotProductAttention
        (not role.module_type and ("dpa_output" in role.name or "dpa_grad_input" in role.name))
        # MHA-wired DPA: output_quantizer_role set by _update_output_quantizer_roles
        or (role.module_type == "linear" and role.tensor_type == "input")
    )
)

transformer_engine/pytorch/module/layernorm_mlp.py, line 2018
Forward slot 2 and slot 3 carry identical roles, making them indistinguishable to factories

Slot 2 (GEMM1_OUTPUT) is the output of FC1, and slot 3 (GEMM2_INPUT) is the input quantizer for FC2. Both are assigned QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name).

Because the roles are identical, a custom factory cannot distinguish between the two quantizer slots for this internal activation. This mirrors the forward design choice in other modules (where GEMM1_OUTPUT and GEMM2_INPUT represent the same logical tensor), but it differs from the output-boundary convention: in Linear and LayerNormLinear, the analogous slot uses tensor_type="output" for the GEMM output position.

Consider using tensor_type="output" for slot 2 to allow factories to differentiate:

if fwd:
    base = [
        QuantizerRole(module_type="linear", tensor_type="input",  name=fc1_name),   # GEMM1_INPUT
        QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name),   # GEMM1_WEIGHT
        QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),   # GEMM1_OUTPUT (fc1→fc2 activation)
        QuantizerRole(module_type="linear", tensor_type="input",  name=fc2_name),   # GEMM2_INPUT
        QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name),   # GEMM2_WEIGHT
        self._output_quantizer_role,                                                 # GEMM2_OUTPUT
    ]

negvet and others added 2 commits March 4, 2026 16:03
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

tests/pytorch/test_custom_recipe.py, line 104
("grad_output") is a parenthesized string, not a tuple. The in operator performs substring matching rather than membership testing. This happens to work for the current set of tensor_type values, but is incorrect Python and could silently fail if future tensor_type values contain "grad_output" as a substring.

Use a proper equality check instead:

        if role.tensor_type == "grad_output":

Or alternatively role.tensor_type in ("grad_output",) with a trailing comma to create a real single-element tuple. This pattern appears in multiple test functions: lines 104, 141, 203, and 260.
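The pitfall is easy to demonstrate in isolation (standard Python semantics, not TE-specific code):

```python
# ("grad_output") is just the string "grad_output"; the parentheses do not
# create a tuple, so `in` does substring matching against a string.
assert ("grad_output") == "grad_output"
assert "grad" in ("grad_output")          # True: substring match
assert "grad" not in ("grad_output",)     # trailing comma makes a real tuple
assert "grad_output" in ("grad_output",)  # proper membership test
print("ok")
```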


transformer_engine/pytorch/module/layernorm_mlp.py, line 2008
Slots 2 and 3 are assigned identical QuantizerRole objects, making them indistinguishable to factory dispatch:

slot 2 (GEMM1_OUTPUT): QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name)
slot 3 (GEMM2_INPUT):  QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name)

These represent physically different tensors (output of fc1 before GeLU vs. input to fc2 after GeLU) and should have distinct roles so factories can apply different quantization settings per slot.

Slot 2 should identify as the fc1 output:

QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),  # slot 2: GEMM1 output
QuantizerRole(module_type="linear", tensor_type="input",  name=fc2_name),  # slot 3: GEMM2 input

This allows factories to inspect role.name and role.tensor_type independently for fine-grained per-slot quantization control.

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (4)

tests/pytorch/nvfp4/test_nvfp4_module_exact.py, line 116
AttributeError when factory is called with role=None

Since the Linear/LayerNormLinear/etc. output and grad-input slots now default to None (when _output_quantizer_role / _grad_input_quantizer_role are not set), qfactory(None) will be called for those slots.

The inner factory starts with if role.tensor_type == "input": without first checking role is None, which will immediately raise AttributeError: 'NoneType' object has no attribute 'tensor_type' when run through a CustomRecipe.

The same crash pattern exists in tests/pytorch/distributed/run_numerics_exact.py at the equivalent factory.

    def factory(role):
        if role is None:
            return None
        if role.tensor_type == "input":

transformer_engine/pytorch/module/layernorm_mlp.py, line 2153
Duplicate backward role for slot 3 (fc2 grad-input / fc1 output-grad)

Slot 3 (the fc2 GEMM's grad-input, which is the internal boundary gradient flowing back through the activation into fc1's output) is assigned the same role as slot 0:

# slot 0: fc1 grad_output
QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name),
...
# slot 3: fc2 grad_input (internal boundary) — but labeled identically to slot 0
QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name),

A user factory dispatching on role.name and role.tensor_type has no way to distinguish the actual fc1 grad_output (slot 0, flowing in from upstream) from the internal fc2→fc1 boundary gradient (slot 3). In fine-grained quantization scenarios (e.g. the test_custom_recipe_quantization_targets example), this could cause the wrong quantizer type to be selected for the internal boundary.

Consider using a distinct role for slot 3, such as:

QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name),

or a dedicated tensor_type like "internal_grad" to make the slot semantically distinguishable.


transformer_engine/pytorch/module/layernorm_mlp.py, line 2145
Forward slots 2 and 3 share identical roles

Slot 2 is the output of fc1 (the internal boundary feeding into the activation before fc2) and slot 3 is fc2's actual input quantizer — yet both are assigned:

QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name)

While the docstring notes "Internal boundaries use fixed roles with known consumer identity" (so slot 2's consumer-perspective labelling is intentional), the practical effect is that the factory receives two indistinguishable roles. For example, a factory that applies a special quantization for fc2's input will apply the same quantizer to both the fc1 output boundary and the fc2 GEMM input, which may not match user intent when fine-tuning quantization per-slot.

Consider naming the boundary slot differently, e.g. with a tensor_type of "output" for fc1 and "input" for fc2:

QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),   # slot 2
QuantizerRole(module_type="linear", tensor_type="input",  name=fc2_name),   # slot 3

transformer_engine/pytorch/quantization.py, line 2609
amax_compute_algo equality check may fail for callable objects

When validating that all DelayedScalingRequest instances in a state share identical parameters, the code uses if v0 != vi for the amax_compute_algo field, which can be a Callable. Comparing two distinct callable objects with != in Python compares by identity, not by value — so two factories that are functionally identical but different objects would fail this check even if the user passes the same function from a different import path or via a lambda. Meanwhile scaling_factor_compute_algo has the same problem.

This could produce a confusing ValueError in valid usage (e.g. two layers created from the same factory function with a fresh closure each time). Consider documenting this limitation prominently, or restricting the equality check to non-callable fields.
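A short illustration of the comparison pitfall (standard Python semantics; the factory name below is hypothetical):

```python
# Function objects have no value-based equality: `!=` falls back to
# identity comparison, so two functionally identical callables differ.
def make_amax_algo():
    return lambda amax_history: max(amax_history)

algo_a = make_amax_algo()
algo_b = make_amax_algo()

assert algo_a != algo_b   # distinct objects, despite identical behavior
assert algo_a == algo_a   # same object compares equal
assert algo_a([1, 3, 2]) == algo_b([1, 3, 2]) == 3
```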

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

transformer_engine/pytorch/module/layernorm_mlp.py, line 2008
Duplicate forward roles for GEMM1_OUTPUT and GEMM2_INPUT

Slots 2 and 3 in the forward base list are identical: both are QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name). Slot 2 corresponds to GEMM1_OUTPUT (fc1's output tensor) and slot 3 to GEMM2_INPUT (fc2's input tensor). These are assigned the same role, making them indistinguishable to a custom factory.

For LayerNormMLP with num_gemms=2, the forward slot layout is:

Slot 0 → GEMM1_INPUT   (fc1 input)
Slot 1 → GEMM1_WEIGHT  (fc1 weight)
Slot 2 → GEMM1_OUTPUT  (fc1 output / intermediate activation)  ← currently labeled as fc2 input
Slot 3 → GEMM2_INPUT   (fc2 input)                             ← labeled as fc2 input (duplicate)
Slot 4 → GEMM2_WEIGHT  (fc2 weight)
Slot 5 → GEMM2_OUTPUT  (module output)

Slot 2 should distinguish the fc1 output from the fc2 input. If the intent is to label the internal fc1→fc2 boundary from the consumer perspective, slot 2 should at minimum carry a different tensor_type (e.g. "output") or a different name:

                QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),
                QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name),

transformer_engine/pytorch/module/layernorm_mlp.py, line 2017
Duplicate backward roles for GEMM1_GRAD_OUTPUT and GEMM2_GRAD_INPUT

Slots 0 and 3 in the backward base list are identical: both are QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name). Slot 0 is GEMM1_GRAD_OUTPUT (the gradient flowing into fc1's backward GEMM) and slot 3 is GEMM2_GRAD_INPUT (the internal boundary — the gradient of the intermediate activation between fc1 and fc2, which conceptually is also fc1's grad_output).

While conceptually the same tensor (both represent the gradient at the fc1↔fc2 boundary), having two slots with identical roles means a factory cannot distinguish GEMM1_GRAD_OUTPUT from GEMM2_GRAD_INPUT. A factory doing fine-grained slot-level dispatch on (tensor_type, name) will receive two calls with the same arguments.

The slot at index 3 should ideally carry a distinct role. If this is intentional (both slots logically represent the same tensor), a comment to that effect would help future maintainers understand why the duplication is deliberate rather than a copy-paste error.


transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py, line 74
Misleading inline comments invert the parameter semantics

Both inline comments describe the opposite effect of the value they annotate:

  • force_pow_2_scales=False, # constrain scale to powers of 2 — Setting this to False explicitly does not constrain scales to powers of 2. The comment implies it does.
  • amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero — An epsilon of 0.0 provides no lower-bound clamping and does nothing to prevent division by zero. A positive epsilon (e.g. 1e-12) would serve that purpose.

The same misleading comment appears on amax_epsilon=0.0 in float8_block_scaling_quantizer_factory at line 117.

        force_pow_2_scales=False,  # allow non-power-of-2 scales (matches Float8CurrentScaling defaults)
        amax_epsilon=0.0,  # no lower-bound clamping on amax (matches Float8CurrentScaling defaults)
