
[NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel#2555

Open
zhongbozhu wants to merge 22 commits into NVIDIA:main from zhongbozhu:zhongbo/dense_row_col_rht_fp4_quant

Conversation

@zhongbozhu
Collaborator

@zhongbozhu zhongbozhu commented Jan 3, 2026

Description

Note: #2558 reported a bug in #2411. The fix is in #2564; make sure to cherry-pick that PR as well until it lands in main.

Previously, a similar optimization was applied to MoE grouped quantize with RHT in #2411. This PR targets dense linear layers and shared experts when quantizing to NVFP4. With this fusion, the high-precision input only needs to be read once; without it, it must be read twice.

Similarly, the env var NVTE_USE_FAST_MATH controls the numerical behavior of the RHT quant fusion kernel to accelerate it further. Fast math is only applied to the high-precision math, so it has minimal impact on training convergence.

What fast-math toggle controls:

  1. Replace x / y with x * (1/y).
  2. Replace 1 / x with reciporal_approximate_ftz(x).
  3. When the RHT cast fusion is available, NVFP4 quantization can be performed directly on FP32 data in register files, eliminating a round trip from FP32 to BF16 and back to FP32.
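To see why rewrites 1 and 2 change bit patterns, here is a hedged NumPy sketch (an emulation only, not the actual CUDA code; `div_fast` is a hypothetical stand-in for the reciprocal path):

```python
import numpy as np

def div_exact(x, y):
    # IEEE-correctly-rounded division: one rounding step
    return x / y

def div_fast(x, y):
    # fast-math rewrite: x / y -> x * (1 / y); two rounding steps
    return x * (np.float32(1.0) / y)

x = np.float32(7.1)
y = np.float32(3.3)
exact = div_exact(x, y)
fast = div_fast(x, y)
# The two results agree to within a few ULPs but need not be bit-identical,
# which is exactly what breaks a zero-tolerance comparison against a reference.
assert abs(float(exact) - float(fast)) <= 4 * np.finfo(np.float32).eps * abs(float(exact))
```

The hardware `reciprocal_approximate_ftz` path additionally flushes denormals to zero, so its deviation can be larger than this two-rounding sketch suggests.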

Therefore, I DO recommend turning it on, since it significantly improves RHT kernel performance.

The only reason it is still not on by default is that we want ZERO TOLERANCE tests between our CUDA quantize kernels and our PyTorch-based emulated quantize references. With the fast math toggle turned on, it is hard to pass tests with zero tolerance without further investigation into how to relax the test conditions while still keeping high confidence in the test cases.
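The two test regimes can be sketched as follows (illustrative Python only; the real tests live in tests/pytorch/nvfp4/ and compare packed FP4 outputs, and `check_relaxed` is a hypothetical relaxation, not the project's actual criterion):

```python
import numpy as np

def check_zero_tolerance(kernel_out, reference_out):
    # Elementwise exact equality: any fast-math deviation fails this.
    return np.array_equal(kernel_out, reference_out)

def check_relaxed(kernel_out, reference_out, ulps=2):
    # Hypothetical relaxed criterion: allow a few ULPs of FP32 error.
    eps = np.finfo(np.float32).eps
    return np.allclose(kernel_out, reference_out, rtol=ulps * eps, atol=0.0)

ref = np.float32([1.0, 2.0, 3.0])
out = ref * (np.float32(1.0) / np.float32(1.0))  # bit-identical in this toy case
assert check_zero_tolerance(out, ref)
assert check_relaxed(out, ref)
```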

TODO items:

  • Merge the bug fix PR #2564 first.
  • Some deprecated Cutlass APIs are still in use and produce many warnings.
  • Maybe turn on fast math by default and use NVTE_DISABLE_RHT_FAST_MATH instead of NVTE_USE_FAST_MATH? @timmoon10 for opinions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zhongbozhu zhongbozhu requested a review from timmoon10 January 3, 2026 01:23
@zhongbozhu zhongbozhu self-assigned this Jan 3, 2026
@zhongbozhu zhongbozhu added the fp4 label Jan 3, 2026
@greptile-apps
Contributor

greptile-apps bot commented Jan 3, 2026

Greptile Summary

This PR integrates a new Blackwell-native CUTLASS UMMA fusion kernel (row_cast_col_hadamard_transform_cast_fusion.cu) for dense linear layers and shared experts under NVFP4, completing an optimization that was previously only available for MoE grouped quantization (PR #2411). The new kernel reads the high-precision BF16 input only once and performs row-wise quantization, the 16×16 Random Hadamard Transform (RHT), and column-wise quantization + transpose in a single GPU pass. An optional NVTE_USE_FAST_MATH env-var enables further approximation (reciprocal, FTZ) of the high-precision math at the cost of exactness.
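For reference, the 16×16 Random Hadamard Transform mentioned above can be emulated in NumPy as follows (a sketch of the math only, not the UMMA kernel; the sign-randomization scheme here is illustrative, not necessarily the one the kernel uses):

```python
import numpy as np

def hadamard(n):
    # Sylvester construction; n must be a power of two.
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
signs = rng.choice([-1.0, 1.0], size=16)            # random diagonal sign flip
rht_matrix = (hadamard(16) * signs) / np.sqrt(16.0)  # normalized -> orthogonal

x = rng.standard_normal((64, 16)).astype(np.float32)
y = x @ rht_matrix                                   # transform each 16-element group
# The RHT is orthogonal, so it preserves norms (up to float rounding); this is
# why it can spread outliers before quantization without changing the signal energy.
assert np.allclose(np.linalg.norm(x), np.linalg.norm(y), rtol=1e-4)
```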

Key changes:

  • New dense fusion kernel row_cast_col_hadamard_transform_cast_fusion.cu (~1400 lines) with Blackwell arch-guard, NVTE_CHECK dimension validation, and template parameters for stochastic rounding / fast math / row-only / col-only modes.
  • Refactored NVFP4Quantizer::quantize_impl: fusion path calls nvte_quantize_with_hadamard_transform; unfused fallback extracted into quantize_with_rht_unfused_helper.
  • Correctness improvement: need_separate_columnwise_rng now additionally requires !eligible_for_rht_cast_fusion, fixing unnecessary dual-RNG-state allocation in the fused path. Missing set_nvfp4_2d_quantization on quant_config_columnwise also added.
  • Tests extended to cover columnwise_only, rowwise_only, and both_directions quantization modes.
  • Architecture guard added to hadamard_transform_cast_fusion.cu for the existing rht_gemm_device kernel.
  • Known outstanding item: Cutlass deprecated API warnings are still present (noted in PR TODO).

Confidence Score: 3/5

  • PR is functionally correct for common training shapes but introduces a performance regression in the unfused path and has a pending bug-fix dependency (PR [NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant #2564).
  • The fused-path logic, RNG-state handling, and amax computation look correct. The test coverage is solid, including non-fusion-eligible shapes. Two concerns lower confidence: (1) a concrete regression where rht_output_t is allocated unconditionally in the unfused path regardless of columnwise_usage, wasting a cols×rows BF16 buffer for every odd-shaped rowwise-only RHT input; and (2) the PR description explicitly notes a dependency on PR [NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant #2564 (a separate bug fix) that must be cherry-picked. The Cutlass deprecation warnings are outstanding but cosmetic.
  • transformer_engine/pytorch/csrc/quantizer.cpp — the unfused-path rht_output_t allocation regression is here.

Important Files Changed

Filename Overview
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu New dense-linear fusion kernel implementing Row-wise quantization + RHT + Columnwise quantization in a single Blackwell UMMA/TMA pass; correctly uses NVTE_CHECK for dimension validation and properly guards architecture-specific code with if constexpr. Swizzle SF output is stubbed out as false pending a future TODO.
transformer_engine/pytorch/csrc/quantizer.cpp Refactors NVFP4Quantizer to use new fused dense-kernel when eligible; introduces quantize_with_rht_unfused_helper for cleaner fallback path. Contains a regression: rht_output_t buffer is allocated unconditionally in the unfused path even when columnwise_usage=false, wasting a cols×rows GPU allocation for every odd-shaped rowwise-only RHT input.
transformer_engine/common/include/transformer_engine/hadamard_transform.h Adds declaration for nvte_quantize_with_hadamard_transform and deprecates nvte_hadamard_transform_cast_fusion_columnwise; clean, minimal change.
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu Adds Blackwell-architecture guard (if constexpr (!is_blackwell_arch)) to rht_gemm_device kernel; the structural change (moving original body into else branch) is balanced and correct.
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py Extends tests to cover columnwise_only, rowwise_only, and both_directions modes, and properly guards assertions behind if return_rowwise; test shapes include non-fusion-eligible dimensions (e.g. 256×272) that exercise the unfused fallback path.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[NVFP4Quantizer::quantize_impl] --> B{with_rht?}
    B -- No --> C[nvte_quantize_v2\nrow-wise only]
    B -- Yes --> D{eligible_for_rht_cast_fusion?\nbf16 input AND rows%64==0 AND cols%128==0}
    D -- Yes Fused path --> E[nvte_quantize_with_hadamard_transform\nSingle Blackwell UMMA kernel:\n1. Row-wise quantize input\n2. Apply 16x16 RHT via GEMM\n3. Col-wise quantize + transpose]
    D -- No Unfused path --> F[allocate rht_output_t buffer\ncols x rows BF16]
    F --> G[quantize_with_rht_unfused_helper]
    G --> H{rowwise_usage?}
    G --> I{columnwise_usage?}
    H -- Yes --> J[nvte_quantize_v2\nrow-wise quantize input directly]
    I -- Yes --> K[nvte_hadamard_transform\ncompute RHT of transposed input]
    K --> L[nvte_quantize_v2\nquantize RHT output]
    E --> M[Output: rowwise FP4 + scale_inv\nAND/OR columnwise FP4 + scale_inv]
    J --> M
    L --> M

Last reviewed commit: 999fe85

@zhongbozhu zhongbozhu changed the title [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose Fusion Kernel [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel Jan 3, 2026
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from c80932f to fc42825 Compare January 3, 2026 04:16
@zhongbozhu
Collaborator Author

/te-ci arm L1

Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (2)

  1. benchmarks/linear/benchmark_linear.py, line 141 (link)

    logic: NVTX range is pushed but never popped in the benchmark function

  2. transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu, line 346 (link)

    syntax: Typo in comment: 'SMEMork' should be 'SMEM work'

8 files reviewed, 2 comments


@zhongbozhu zhongbozhu added the MoE label Jan 6, 2026
@zhongbozhu zhongbozhu changed the title [NVFP4][Dense] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel, Fixing NVFP4 Group Quant Bug Jan 6, 2026
@zhongbozhu zhongbozhu changed the title [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel, Fixing NVFP4 Group Quant Bug [NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel Jan 6, 2026
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from 2bc695e to 6ea9dab Compare January 9, 2026 23:14
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR integrates a Cutlass-based fusion kernel that combines row-wise quantization and column-wise RHT (Random Hadamard Transform) + quantization + transpose operations for NVFP4 dense linear layers and shared experts. The key optimization reduces memory bandwidth by reading high-precision input data once instead of twice.

Key Changes

New Fusion Kernel (row_cast_col_hadamard_transform_cast_fusion.cu):

  • Implements nvte_hadamard_transform_cast_fusion API that performs both rowwise and columnwise quantization in a single pass
  • Uses MMA hardware for efficient Hadamard transform computation
  • Eligible when input is BF16 with dimensions divisible by 64×128
  • Reads pre-computed amax values to calculate FP8 scaling factors
  • Supports stochastic rounding and fast math optimization flags

Refactored Quantizer Logic (quantizer.cpp):

  • Moved unfused RHT path into quantize_with_rht_unfused_helper method for cleaner code organization
  • Improved RNG state handling: single RNG state when fusion is used, separate states for rowwise/columnwise when unfused
  • Added NVTE_USE_FAST_MATH environment variable support for accelerating high-precision math operations
  • Eligibility check moved before RNG state generation to avoid unnecessary work

Extended Test Coverage (test_nvfp4_rht_quantize_exact.py):

  • Added "columnwise-only" quantization mode testing alongside existing "quantize" and "quantize_transpose" modes
  • Tests now validate rowwise/columnwise results conditionally based on the quantization mode

Grouped Quantization Support (cast.cpp):

  • Split-quantize path now uses fused kernel when all tensors have 128-aligned dimensions
  • Bulk RNG state generation for grouped kernels (single state shared across splits)
  • Fast math flag propagation to all quantization configs

Architecture Notes

The fusion provides optimal performance when:

  1. Input dtype is BF16
  2. Rows are divisible by 64 (MMA tile requirement)
  3. Columns are divisible by 128 (MMA tile requirement)

When these conditions aren't met, the code gracefully falls back to the unfused path with separate kernel launches for rowwise and columnwise quantization.
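The eligibility rule above can be restated as a small predicate (a Python paraphrase for illustration; the real check lives in transformer_engine/pytorch/csrc/quantizer.cpp and keys on actual tensor dtypes):

```python
def eligible_for_rht_cast_fusion(dtype: str, rows: int, cols: int) -> bool:
    """Paraphrase of the dense-fusion eligibility check described above."""
    return dtype == "bfloat16" and rows % 64 == 0 and cols % 128 == 0

# Shapes aligned to the MMA tiles take the fused single-pass kernel...
assert eligible_for_rht_cast_fusion("bfloat16", 2048, 4096)
# ...while odd shapes (e.g. the 256x272 test case) fall back to the unfused path,
# as does any non-BF16 input.
assert not eligible_for_rht_cast_fusion("bfloat16", 256, 272)
assert not eligible_for_rht_cast_fusion("float16", 2048, 4096)
```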

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk after addressing documentation and TODO items mentioned in the PR description
  • Score of 4 reflects a well-engineered feature with thorough implementation. The code demonstrates good software practices: clean refactoring with extracted helper methods, proper error handling, graceful fallback paths, and comprehensive test coverage including the new columnwise-only mode. The fusion kernel follows established patterns from the grouped quantization PR #2411. Deducted 1 point due to: (1) PR author notes cutlass deprecation warnings need addressing, (2) TODOs remain about potentially defaulting fast math on, and (3) the ~1400 line CUDA kernel file has limited inline documentation for complex template logic
  • The main CUDA kernel file (row_cast_col_hadamard_transform_cast_fusion.cu) would benefit from additional inline comments explaining the template parameter switches and MMA computation flow, but no files have critical issues requiring immediate attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/csrc/quantizer.cpp 4/5 Refactored NVFP4 quantize_impl to use new fused RHT cast kernel, extracted unfused helper, improved RNG state handling for fused vs unfused paths
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu 4/5 New CUDA kernel implementing fused row-cast and column-RHT-transpose-cast using Cutlass MMA hardware for BF16 inputs with 64x128 alignment
transformer_engine/common/include/transformer_engine/hadamard_transform.h 5/5 Added new API function nvte_hadamard_transform_cast_fusion for dense layer quantization, marked old columnwise function for future deprecation
transformer_engine/pytorch/csrc/extensions/cast.cpp 4/5 Added NVTE_USE_FAST_MATH env var support in split_quantize for grouped NVFP4 kernels, improved RNG state setup with bulk generation flag
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py 5/5 Extended test coverage to support columnwise-only quantization mode, added return_identity parameter to test all three modes

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Quantizer as NVFP4Quantizer
    participant API as nvte_hadamard_transform_cast_fusion
    participant Kernel as row_col_rht_gemm_ntt_w_sfc
    participant AmaxKernel as nvte_hadamard_transform_amax
    
    User->>Quantizer: quantize(input, output)
    Quantizer->>Quantizer: Check eligibility (BF16, rows%64==0, cols%128==0)
    
    alt With RHT and eligible for fusion
        Quantizer->>AmaxKernel: Compute rowwise & columnwise amax
        AmaxKernel-->>Quantizer: amax values populated
        
        alt Stochastic rounding enabled
            Quantizer->>Quantizer: Generate RNG state
        end
        
        alt Fast math enabled (NVTE_USE_FAST_MATH)
            Quantizer->>Quantizer: Set use_fast_math flag
        end
        
        Quantizer->>API: Call with input, output, hadamard_matrix, quant_config
        API->>Kernel: Launch fused kernel
        
        Kernel->>Kernel: Read amax values
        Kernel->>Kernel: Perform rowwise quantization to FP4
        Kernel->>Kernel: Compute RHT using MMA hardware
        Kernel->>Kernel: Transpose and quantize to FP4
        Kernel->>Kernel: Write FP8 scales
        
        Kernel-->>API: Complete
        API-->>Quantizer: Return
        
    else Not eligible for fusion
        Quantizer->>AmaxKernel: Compute amax
        AmaxKernel-->>Quantizer: amax values
        
        alt Rowwise usage
            Quantizer->>Quantizer: Call nvte_quantize_v2 for rowwise
        end
        
        alt Columnwise usage
            Quantizer->>Quantizer: Call nvte_hadamard_transform for RHT
            Quantizer->>Quantizer: Call nvte_quantize_v2 for columnwise
        end
    end
    
    Quantizer-->>User: Quantized output

Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


Contributor

@greptile-apps greptile-apps bot left a comment


7 files reviewed, 7 comments


@greptile-apps
Contributor

greptile-apps bot commented Jan 10, 2026

Additional Comments (2)

transformer_engine/pytorch/csrc/extensions/cast.cpp
The alignment check split_section % 128 == 0 for fusion eligibility should be documented or extracted as a constant. This magic number appears in multiple places (quantizer.cpp line 1552 uses different values) and the inconsistency suggests potential bugs. Consider defining constexpr size_t NVFP4_FUSION_ROW_ALIGNMENT = 128; to centralize this requirement.


transformer_engine/pytorch/csrc/quantizer.cpp
When need_separate_columnwise_rng is true, a new RNG state is generated for columnwise quantization. However, the philox generator state is advanced for each call, which means the random sequences will differ between runs if the fusion eligibility changes (e.g., due to shape variations). For reproducibility across different code paths, consider using a deterministic offset based on tensor properties rather than sequential generator advancement.

Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, 2 comments


@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch 3 times, most recently from 011169e to 39f9272 Compare January 29, 2026 00:08
Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 3 comments


Comment on lines +1644 to +1648
// 1. Rowwise quantization
// 2. RHT followed by columnwise quantization & transpose
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(),
quant_config, stream);
Contributor


consider documenting the performance impact of NVTE_USE_FAST_MATH

Since the PR description strongly recommends enabling fast math for significant performance improvement, consider adding a comment here explaining the expected performance gain and why it's recommended for production use (currently only noted in the PR description).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +379 to +385
static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0;
static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0;
static int constexpr NumMmaThreadCount = kEnableRHTColQuant? 32: 0;
static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant? 1: 0;
static int constexpr NumSchedThreads = 32;
static int constexpr NumMainloopLoadThreads = 32;
static int constexpr NumEpilogueThreads = NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount;
Contributor


consider adding rationale for thread counts

Adding a brief comment explaining why these specific thread counts (32 MMA, 128 col quant, 256 row quant) were chosen would help future maintainers understand the workload distribution design.


Comment on lines +1517 to +1518
// Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
bool eligible_for_rht_cast_fusion =
Contributor


consider documenting cols % 128 requirement

While confirmed intentional in previous threads, adding a comment explaining why the dense kernel requires cols % 128 (likely UMMA tile alignment) would prevent future confusion, especially since MOE uses different alignment.


@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from 39f9272 to fd29f6b Compare January 29, 2026 00:39
Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, no comments


@zhongbozhu
Copy link
Collaborator Author

/te-ci arm L1

@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from fd29f6b to 0864f99 Compare January 30, 2026 19:36
Contributor

@greptile-apps greptile-apps bot left a comment


4 files reviewed, no comments


@zhongbozhu
Copy link
Collaborator Author

@ptrendx can we merge this PR soon?

// and inconsistently implemented.
// What math is accelerated? Only the high precision math, so numerical impact is minimal
// 1. replace x / y by x * (1/y)
// 2. replace 1 / x by reciporal_approximate_ftz(x)
Member


Point 2 is scary. I've heard that in the past there were issues with the FTZ setting for numerics.

Collaborator Author


Yeah, however, this part of fp32 math is pretty important for the fp4 quantization cost to be hidden under RHT gemm. Users are expected to know the danger and enable it with NVTE_USE_FAST_MATH.

int k_tile_size = 1024;

// TODO: add support for swizzle sf output
const bool use_swizzle_sf_output = false;
Member


We then should check whether the output tensor requested swizzling and fail if so.

Collaborator Author


This one has more complication, because of the requirement of triggering this fusion is M being 64 multiple instead of 128 multiple. However, the swizzle fusion does requires 128 multiple. For moe, we can zero pad, but we cannot zero pad for dense easily. Unless we lift up the requirement of triggering RHT fusion once for all, or we have some more dispatch logic about if we use this fusion, do we trigger swizzle fusion or not.

Member


Sure, that is something we can solve in the followup, but at least for now we should at least fail explicitly if the output tensor has that option set to make sure that there is no silent wrong answer happening.

// 2. RHT is enabled
// 3. Columnwise usage is enabled
// 4. Rowwise and columnwise quantization are not fused,
// because within a single kernel we can generate two different random numbers for rowwise and columnwise
Member


Did we make sure that when we request the RNG state that we tell it that we will generate twice as many numbers? I don't see any changes to that in the diff.

Collaborator Author


The rng per-thread has been updated to be num_tensor * 1024.

Member


And how many numbers do we actually need here?

Collaborator Author


I shouldn't have said num_tensor * 1024, since num_tensor==1.

Within the kernel, which is a persistent kernel, RNG generation already uses the current data tile index rather than the CTA index, so the RNG count only needs to account for one CTA: https://github.com/zhongbozhu/TransformerEngine/blob/4f24746b81f011c826dca290f5cc9eb5732e4cf3/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu#L962

For this kernel, 4 warps handle colwise quantization and 8 warps handle rowwise quantization. One tile processed at a time is 128x128 elements, so in the worst case for colwise quantize that's 128 * 128 / 4 / 32 = 128 elements touched per thread; 1024 is well above that.
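The worst-case arithmetic above can be sanity-checked directly, using only the figures quoted in this thread:

```python
# One persistent-kernel tile is 128x128 elements; 4 warps (32 threads each)
# handle the colwise quantization of that tile.
tile_elems = 128 * 128
colwise_threads = 4 * 32
elems_per_thread = tile_elems // colwise_threads
assert elems_per_thread == 128   # well under the 1024 RNG values reserved per thread
assert elems_per_thread <= 1024
```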

zhongbozhu and others added 18 commits March 3, 2026 12:49
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from 4f24746 to e85d6de Compare March 3, 2026 23:44
@zhongbozhu zhongbozhu force-pushed the zhongbo/dense_row_col_rht_fp4_quant branch from e85d6de to 3a93a72 Compare March 3, 2026 23:46
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/pytorch/csrc/quantizer.cpp, line 2292
Unnecessary intermediate buffer allocation when columnwise_usage=false

In the refactored unfused path, rht_output_t (a transposed buffer of shape cols × rows) is allocated unconditionally, even when columnwise_usage=false. In the old code this allocation was guarded inside if (columnwise_usage). The helper function itself skips the columnwise path when columnwise_usage=false, so this buffer is allocated but never used.

For large models (e.g. hidden_dim=4096, seq_len=2048, BF16) this wastes ~16MB of GPU memory per call and incurs unnecessary allocation overhead.

} else {
  auto& columnwise_quant_config_to_use =
      need_separate_columnwise_rng ? quant_config_columnwise : quant_config;
  // Only allocate the RHT transposed buffer if columnwise quantization is needed
  at::Tensor rht_output_t;
  TensorWrapper rht_output_t_cpp;
  if (columnwise_usage) {
    rht_output_t =
        allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
    rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                      std::vector<size_t>{cols, rows});
  }
  this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config,
                                         columnwise_quant_config_to_use, stream);
}

@zhongbozhu
Collaborator Author

/te-ci arm L1

@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

transformer_engine/pytorch/csrc/quantizer.cpp, line 2289
Spurious memory allocation in unfused path when columnwise_usage=false

rht_output_t (a [cols × rows] BF16 tensor) is allocated unconditionally at lines 2282-2283 every time the unfused path is taken, even when columnwise_usage=false.

Inside quantize_with_rht_unfused_helper, this buffer is only accessed inside if (columnwise_usage) (line 2058), so for rowwise-only quantization the allocation is wasted. In the original code, this tensor was allocated inside if (columnwise_usage).

Consider guarding the allocation:

at::Tensor rht_output_t;
TensorWrapper rht_output_t_cpp;
if (columnwise_usage) {
    rht_output_t =
        allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
    rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                      std::vector<size_t>{cols, rows});
}
this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config,
                                       columnwise_quant_config_to_use, stream);

transformer_engine/pytorch/csrc/quantizer.cpp, line 2271
Fused kernel used for rowwise-only case, potentially regressing performance

When rowwise_usage=true && columnwise_usage=false && eligible_for_rht_cast_fusion, the new code calls the full Cutlass GEMM-based nvte_quantize_with_hadamard_transform fused kernel (line 2269), even though RHT computation is not needed.

The fused kernel carries significant overhead — TMA pipeline infrastructure, cluster-launch-control scheduler, 512-thread CTAs — when only rowwise quantization is required. For inference workloads that only request rowwise output, dispatching a heavier GEMM kernel is slower than the simpler nvte_quantize_v2 approach.

Consider short-circuiting to nvte_quantize_v2 for rowwise-only cases:

if (eligible_for_rht_cast_fusion && columnwise_usage) {
    // fused path: rowwise + RHT columnwise in one kernel
    NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
               "RHT matrix is not available.");
    auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
    NVTE_SCOPED_GIL_RELEASE({
      nvte_quantize_with_hadamard_transform(input.data(), out.data(), rht_matrix_nvte.data(),
                                            quant_config, stream);
    });
} else if (eligible_for_rht_cast_fusion && !columnwise_usage) {
    // rowwise only – simpler kernel is sufficient
    NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
} else {
    // unfused fallback path
    ...
}

transformer_engine/pytorch/csrc/quantizer.cpp, line 2257
NVTE_USE_FAST_MATH env-var read on every quantize_impl call

transformer_engine::getenv<bool>("NVTE_USE_FAST_MATH") is called inside quantize_impl (line 2253), which executes on every forward/backward pass. For high-throughput training loops, this adds a getenv syscall on every kernel dispatch.

Since the PR notes that fast-math is "strongly recommended for performance", consider caching this value at construction time in the NVFP4Quantizer class or at module load time rather than re-reading it on every invocation.

@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/pytorch/csrc/quantizer.cpp, line 2288
Unnecessary buffer allocation when columnwise_usage=false

In the unfused path, rht_output_t (a cols × rows GPU tensor) is now allocated unconditionally before calling quantize_with_rht_unfused_helper, even when columnwise_usage=false (rowwise-only quantization). Inside the helper, rht_output_t_cpp is only used inside the if (columnwise_usage) block, so the allocation is wasted for the rowwise-only case.

Before this refactoring the allocation was correctly guarded inside if (columnwise_usage). The regression means every odd-shaped input with RHT in rowwise-only mode now incurs an unnecessary GPU memory allocation of size cols × rows × sizeof(BF16).

      // unfused path also needs memory allocation for intermediate buffer for RHT output
      at::Tensor rht_output_t;  // The RHT(x_t) output, in columnwise layout
      // This wrapper is going to be passed as input to the quantization kernel.
      TensorWrapper rht_output_t_cpp;  // Wrapper to contain the RHT(x) and RHT(x_t) outputs
      if (columnwise_usage) {
        rht_output_t =
            allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
        // NOTE (frsun): This is non-intuitive, we are writing the
        // result of transposed RHT to the output of rowwise.
        rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                          std::vector<size_t>{cols, rows});
      }

#include <cuda_runtime.h>

#if CUDA_VERSION >= 12080
#include "common/common.h"
Collaborator Author


This is for having FP4_TYPE_SUPPORTED defined in this header file, which is later referenced in this ptx.cuh

Expect this fix to solve some compile issues for older CUDA versions.
