Conversation
…ipe state Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces `QuantizerRole`-based, role-aware quantizer creation for custom recipes.
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant M as TE Module<br/>(Linear/LayerNormMLP/DPA)
    participant B as TransformerEngineBaseModule
    participant RS as RecipeState.create()
    participant CRS as CustomRecipeState
    participant F as qfactory(role)
    M->>B: forward() → init_fp8_metadata()
    B->>M: get_quantizer_roles(fwd, num_quantizers)
    M-->>B: List[QuantizerRole | None]
    B->>RS: create(recipe, mode, num_quantizers, roles=roles)
    RS->>CRS: __init__() + state.roles = roles
    B->>CRS: make_quantizers()
    loop for each slot i
        CRS->>F: qfactory(roles[i])
        F-->>CRS: Quantizer | DelayedScalingRequest | None
    end
    CRS->>CRS: _handle_delayed_scaling_requests()<br/>(allocate shared DS buffers if needed)
    CRS-->>B: List[Quantizer]
    B-->>M: quantizers ready for GEMM
```
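The per-slot flow in the diagram can be sketched with plain-Python stand-ins. Every class and function below is an illustrative stand-in for the TE type of the same name, not the real implementation; the factory's dispatch rule is invented for the example.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass(frozen=True)
class QuantizerRole:
    # Fields mirror the PR description; everything else here is a stand-in.
    module_type: str
    tensor_type: str
    name: str = ""

@dataclass
class Quantizer:
    # Placeholder for a real TE quantizer object.
    fmt: str

class CustomRecipeState:
    """Toy stand-in: stores roles and asks the factory for one quantizer per slot."""

    def __init__(self,
                 qfactory: Callable[[Optional[QuantizerRole]], Optional[Quantizer]],
                 roles: List[Optional[QuantizerRole]]):
        self.qfactory = qfactory
        self.roles = roles

    def make_quantizers(self) -> List[Optional[Quantizer]]:
        # One factory call per slot; a None role yields None here (recipe default).
        return [self.qfactory(role) for role in self.roles]

def qfactory(role):
    if role is None:
        return None
    # Weights get a different format than activations, purely for illustration.
    return Quantizer("nvfp4" if role.tensor_type == "weight" else "fp8")

roles = [
    QuantizerRole("linear", "input", "fc1"),
    QuantizerRole("linear", "weight", "fc1"),
    None,  # slot with no role hint
]
quantizers = CustomRecipeState(qfactory, roles).make_quantizers()
```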
Last reviewed commit: 0c1ec9b
timmoon10 left a comment:
Overall this design is quite clean and generalizable.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
```python
    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
        QuantizerRole(module_type="linear", tensor_type="output", name=name),
    ]
else:
    base = [
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
    ]
```
"output" and "grad_input" roles don't make sense. In reality, we are implicitly assuming that the tensor will be consumed by another linear-like layer.
Suggested change:

```diff
     base = [
         QuantizerRole(module_type="linear", tensor_type="input", name=name),
         QuantizerRole(module_type="linear", tensor_type="weight", name=name),
-        QuantizerRole(module_type="linear", tensor_type="output", name=name),
+        QuantizerRole(module_type="linear", tensor_type="input", name=name),
     ]
 else:
     base = [
         QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
-        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
+        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
     ]
```
Alternatively, if we want to use the output in FP8 DPA, the right role would be module_type="dpa" with tensor_type="input". We should probably make this configurable. I kind of like that this design is exposing the hidden assumptions we've been making.
I agree about "output" and "grad_input" roles. Setting roles for those slots to None (the safest) and enabling the configuration. Also configured it in MHA.
tests/pytorch/test_custom_recipe.py
```python
assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1
```
Suggested change:

```diff
-assert counts["input"] == 1
-assert counts["weight"] == 1
-assert counts["output"] == 1
-assert counts["grad_output"] == 1
-assert counts["grad_input"] == 1
+assert counts["input"] == 2
+assert counts["weight"] == 1
+assert counts["output"] == 0
+assert counts["grad_output"] == 2
+assert counts["grad_input"] == 0
```
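For reference, counts like these can be gathered from an emitted role list with a `Counter`. This is a sketch; the role list literal below is illustrative and not taken from the actual test.

```python
from collections import Counter
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizerRole:
    # Illustrative stand-in for the TE QuantizerRole dataclass.
    module_type: str
    tensor_type: str
    name: str = ""

# Under the suggested convention, "input" appears twice (the GEMM input and the
# output slot relabeled as the consumer's input) and "output" is never emitted.
roles = [
    QuantizerRole("linear", "input", "fc1"),
    QuantizerRole("linear", "weight", "fc1"),
    QuantizerRole("linear", "input", "fc1"),
]
counts = Counter(r.tensor_type for r in roles)
```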
```python
def is_gemm(self) -> bool:
    """Whether this role belongs to a GEMM-based module."""
    return self.module_type in self.GEMM_MODULE_TYPES
```
I think this is baking in assumptions about which formats are similar (our recent experiences with grouped tensors make me wonder whether the requirements for "linear" and "grouped_linear" will diverge in the future), and it's also not giving us that much convenience.
Suggested change (remove the helper):

```diff
-def is_gemm(self) -> bool:
-    """Whether this role belongs to a GEMM-based module."""
-    return self.module_type in self.GEMM_MODULE_TYPES
```
Additional Comments (2)
About 1d63084

Custom recipe factories can now return stateful quantizer requests (not just stateless quantizers): TE detects these request dataclasses, allocates the required shared state (scale/amax buffers), and replaces them with real quantizer instances wired into the existing infrastructure. Factories cannot create stateful quantizers directly because the shared buffers must be allocated across all slots simultaneously and registered with TE's global state manager for cross-module distributed reduction, recompute, and checkpointing, a lifecycle that only TE can orchestrate. Delayed scaling is supported via `DelayedScalingRequest`.
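The two-phase handling described above (collect requests first, allocate shared state once, then swap in real quantizers) can be sketched as follows. All names and the buffer-sizing rule are illustrative stand-ins, not TE's actual implementation.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@dataclass
class DelayedScalingRequest:
    # Marker a factory returns to say "I need a delayed-scaling quantizer here".
    amax_history_len: int = 1024

@dataclass
class Quantizer:
    kind: str
    slot: int = -1

def handle_delayed_scaling_requests(
    items: List[Union[Quantizer, DelayedScalingRequest, None]],
) -> Tuple[List[Optional[Quantizer]], int]:
    """Toy version: find every request slot first, size one shared buffer for
    all of them at once, then replace each request with a real quantizer."""
    request_slots = [i for i, it in enumerate(items)
                     if isinstance(it, DelayedScalingRequest)]
    # Shared state must cover every requesting slot simultaneously
    # (here: one amax row per requesting slot).
    shared_amax_rows = len(request_slots)
    out = list(items)
    for i in request_slots:
        out[i] = Quantizer(kind="delayed_scaling", slot=i)
    return out, shared_amax_rows

items = [Quantizer("fp8"), DelayedScalingRequest(), None, DelayedScalingRequest()]
resolved, rows = handle_delayed_scaling_requests(items)
```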
```
Forward (3 GEMMs x 3 = 9 slots):
    GEMM1 -> QKV (at ``GEMM1_OUTPUT``),
    GEMM2 -> O (at ``GEMM2_INPUT``),
    GEMM3 -> S (at ``GEMM3_OUTPUT``).

Backward (3 GEMMs x 2 = 6 slots):
    GEMM1 -> dQKV (at ``GRAD_OUTPUT1``),
    GEMM2 -> dO (at ``GRAD_INPUT2``),
    GEMM3 -> dP (at ``GRAD_INPUT3``).
```
This is not intuitive to me - O is created by the last gemm, whereas S is created by the first gemm inside DPA (so should be gemm 2?).
This is not intuitive to me either; it is a legacy of the cuDNN naming convention. Trying to fix it in f21ce2f.
```python
self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

def get_quantizer_roles(
```
I don't really understand how to properly create this function. We should have some documentation here on what is expected: say you have some number of GEMMs, you want to assign potentially different quantizers to their inputs/weights, and you would like to name those roles X/Y, etc.; that is what the docs would need to cover.
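As a starting point for that documentation, here is a hedged sketch of what a `get_quantizer_roles` implementation could look like for a module with several named GEMMs. The slot counts, ordering, and the choice to leave unknown-consumer boundary slots as `None` follow the conventions discussed in this thread; the function signature and names are assumptions, not TE's actual API.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class QuantizerRole:
    # Illustrative stand-in for the TE QuantizerRole dataclass.
    module_type: str
    tensor_type: str
    name: str = ""

def get_quantizer_roles(fwd: bool, gemm_names: List[str]) -> List[Optional[QuantizerRole]]:
    """Sketch: 3 forward slots per GEMM (input/weight/output) and 2 backward
    slots (grad_output/grad_input). Boundary slots whose consumer is unknown
    are left as None so the recipe default applies (the "safest" option)."""
    roles: List[Optional[QuantizerRole]] = []
    for name in gemm_names:
        if fwd:
            roles += [
                QuantizerRole("linear", "input", name),
                QuantizerRole("linear", "weight", name),
                None,  # output slot: consumer unknown, no role hint
            ]
        else:
            roles += [
                QuantizerRole("linear", "grad_output", name),
                None,  # grad_input slot: consumer unknown
            ]
    return roles

fwd_roles = get_quantizer_roles(True, ["fc1", "fc2"])
bwd_roles = get_quantizer_roles(False, ["fc1", "fc2"])
```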
```
This factory demonstrates how to use ``CustomRecipe`` with ``fp8_dpa=True``
to combine NVFP4 quantization for linear layers with FP8 attention.
```
To be honest, I was hoping that something like this would enable getting rid of the fp8_dpa/fp8_mha toggles, or at least make them irrelevant for the custom recipe.
fp8_dpa/fp8_mha is a kernel-dispatch gate. The factory says how to quantize; fp8_dpa/fp8_mha says whether to run the FP8 attention kernel path at all. It is possible to infer the gate from the quantizer type produced by the factory. That would be a refactor, but nothing dramatic. Still, I would propose planning it as a follow-up, not in this PR.
cc @cyanguwa
Additional Comments (3)

The factory is designed for standalone ``fp8_dpa=True`` usage (``fp8_mha=False``). DPA boundary slots are detected as follows:

```python
# DPA boundary slots (O output / dQKV grad-input): the fused attention
# kernel only supports FP8 quantizers here, regardless of the linear recipe.
# NOTE: when fp8_mha=True, MultiheadAttention wires output_quantizer_role with
# module_type="linear", which bypasses this guard. This factory is only
# designed for standalone fp8_dpa=True usage (fp8_mha=False).
is_dpa_boundary = (
    role is not None
    and (
        # standalone DPA: hint-only role emitted by DotProductAttention
        (not role.module_type and ("dpa_output" in role.name or "dpa_grad_input" in role.name))
        # MHA-wired DPA: output_quantizer_role set by _update_output_quantizer_roles
        or (role.module_type == "linear" and role.tensor_type == "input")
    )
)
```
Because the roles for slots 2 and 3 are identical, a custom factory cannot distinguish between the two quantizer slots for this internal activation. This mirrors the forward design choice in other modules (where GEMM1_OUTPUT and GEMM2_INPUT represent the same logical tensor), but it differs from the output-boundary convention. Consider using:

```python
if fwd:
    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name),   # GEMM1_INPUT
        QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name),  # GEMM1_WEIGHT
        QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),  # GEMM1_OUTPUT (fc1→fc2 activation)
        QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name),   # GEMM2_INPUT
        QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name),  # GEMM2_WEIGHT
        self._output_quantizer_role,                                               # GEMM2_OUTPUT
    ]
```
Additional Comments (2)

Use a proper equality check instead.

These represent physically different tensors (the output of fc1 before GeLU vs. the input to fc2 after GeLU) and should have distinct roles so factories can apply different quantization settings per slot. Slot 2 should identify as the fc1 output; this allows factories to inspect which module produced the tensor.
Additional Comments (4)
Since the inner factory starts with an unguarded access, it can crash; the same crash pattern exists elsewhere.
Slot 3 (the fc2 GEMM's grad-input, which is the internal boundary gradient flowing back through the activation into fc1's output) is assigned the same role as slot 0:

```python
# slot 0: fc1 grad_output
QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name),
...
# slot 3: fc2 grad_input (internal boundary), but labeled identically to slot 0
QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name),
```

A user factory dispatching on these roles cannot tell the two slots apart. Consider using a distinct role for slot 3, such as:

```python
QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name),
```

or a dedicated `tensor_type`.
Slot 2 is the output of fc1 (the internal boundary feeding into the activation before fc2) and slot 3 is fc2's actual input quantizer, yet both are assigned:

```python
QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name)
```

While the docstring notes "Internal boundaries use fixed roles with known consumer identity" (so slot 2's consumer-perspective labelling is intentional), the practical effect is that the factory receives two indistinguishable roles. For example, a factory that applies special quantization to fc2's input will apply the same quantizer to both the fc1 output boundary and the fc2 GEMM input, which may not match user intent when fine-tuning quantization per slot. Consider naming the boundary slot differently, e.g.:

```python
QuantizerRole(module_type="linear", tensor_type="output", name=fc1_name),  # slot 2
QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name),   # slot 3
```
When validating that all roles are covered, a gap here could produce a confusing error message.
Additional Comments (3)
Slots 2 and 3 in the forward role list are identical. Slot 2 should distinguish the fc1 output from the fc2 input. If the intent is to label the internal fc1→fc2 boundary from the consumer perspective, slot 2 should at minimum carry a different name.
Slots 0 and 3 in the backward role list are identical. While they are conceptually the same tensor (both represent the gradient at the fc1↔fc2 boundary), having two slots with identical roles means a factory cannot distinguish them. The slot at index 3 should ideally carry a distinct role. If this is intentional (both slots logically represent the same tensor), a comment to that effect would help future maintainers understand why the duplication is deliberate rather than a copy-paste error.
Both inline comments describe the opposite effect of the value they annotate; the same misleading comment appears in more than one place.
Description
Introducing `QuantizerRole`.

This is an API that allows going down to "set this `LayerNormLinear` in this transformer layer to be less aggressively quantized" (a fine-grained, per-module/per-tensor quantization control mechanism). See `test_custom_recipe.py::test_custom_recipe_quantization_targets()`. The quantizer factory uses roles to dispatch according to its needs.

Each TE module/op emits a list of `QuantizerRole`:
- `Linear`, `LayerNormLinear`, `LayerNormMLP` emit `module_type="linear"` with `tensor_type` in `{"input", "weight", "grad_output"}`.
- `GroupedLinear` emits `module_type="grouped_linear"`.
- `CustomRecipe` accepts a `qfactory` callable that receives a `QuantizerRole` and returns a quantizer.

Factories can be composed, e.g., dispatch (optionally to different sub-factories) based on `module_type` (`dpa` vs `linear`) and then refine based on `tensor_type`.

Summary:
- Modules implement `get_quantizer_roles()`, which returns a list of `QuantizerRole` objects.
- In `set_meta_tensor()`, modules call `get_quantizer_roles()` and pass the roles to `RecipeState.create()`.
- `RecipeState.create()` assigns roles to the state (e.g., `CustomRecipeState.roles`).
- `CustomRecipeState.make_quantizers()` calls `qfactory(role)` for each role to create quantizers.
- Factories can inspect `role.module_type`, `role.tensor_type`, and `role.name` to dispatch to different quantizers.
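The composed-factory idea above can be sketched as follows. The returned string tags, the sub-factory names, and the `"decoder.final"` override are illustrative assumptions standing in for real quantizer objects and module names.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizerRole:
    # Illustrative stand-in for the TE QuantizerRole dataclass.
    module_type: str
    tensor_type: str
    name: str = ""

def linear_factory(role):
    # Per-tensor refinement inside the linear sub-factory.
    if role.tensor_type == "weight":
        return "nvfp4_weight"
    if role.name == "decoder.final":  # hypothetical per-module override
        return "bf16_passthrough"
    return "nvfp4_activation"

def dpa_factory(role):
    # Attention slots always get the FP8 path in this sketch.
    return "fp8_attention"

def qfactory(role):
    # Top-level dispatch on module_type, then refine in the sub-factory.
    if role is None:
        return None
    sub = {"linear": linear_factory, "dpa": dpa_factory}.get(role.module_type)
    return sub(role) if sub else None

q1 = qfactory(QuantizerRole("linear", "weight", "encoder.fc1"))
q2 = qfactory(QuantizerRole("linear", "input", "decoder.final"))
q3 = qfactory(QuantizerRole("dpa", "input", "attn"))
```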
Changes
Please list the changes introduced in this PR:
Checklist: