[JAX] Integrate BF16 Grouped GEMM with on-device group sizes#2680

Open
jberchtold-nvidia wants to merge 31 commits into NVIDIA:main from jberchtold-nvidia:gmm

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Feb 13, 2026

Description

Integrate new BF16 grouped GEMM from TE common/cuBLASLt that supports on-device group sizes without a D2H memcpy and stream sync. This grouped GEMM is faster and CUDA-graph safe.

Also fixes #2659

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a new CUDA-graph-safe grouped GEMM to TE/JAX that is automatically used as the backend when the input data is BF16 (no scaling recipe) and no bias is required.
  • Exposed a new make_ragged_dot_cls for easy integration into existing models. This will be most useful once quantization is supported and storing recipe state is required.
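
For reference, the grouped ("ragged") dot semantics this backend accelerates can be sketched in plain NumPy. Names and shapes here are illustrative, not TE's actual API: `lhs` holds M tokens of width K, `rhs` holds one [K, N] weight matrix per group, and `group_sizes` (the on-device array the new backend avoids copying to host) assigns each contiguous block of rows to a group.

```python
import numpy as np

def grouped_gemm_reference(lhs, rhs, group_sizes):
    """Reference grouped GEMM: rows of `lhs` are partitioned into
    contiguous groups; group i is multiplied by rhs[i]."""
    assert group_sizes.sum() == lhs.shape[0]
    out = np.empty((lhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
    row = 0
    for i, size in enumerate(group_sizes):
        out[row:row + size] = lhs[row:row + size] @ rhs[i]
        row += size
    return out

lhs = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 tokens, K=3
rhs = np.stack([np.eye(3, dtype=np.float32),          # group 0: identity
                2 * np.eye(3, dtype=np.float32)])     # group 1: doubles
group_sizes = np.array([1, 3])                        # row 0 | rows 1..3
out = grouped_gemm_reference(lhs, rhs, group_sizes)
```

The fast path dispatches the whole loop to a single cuBLASLt grouped GEMM with the group boundaries read on device.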

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 13, 2026 17:42
@greptile-apps
Contributor

greptile-apps bot commented Feb 13, 2026

Greptile Summary

This PR integrates a new CUDA-graph-safe BF16 grouped GEMM backend (GroupedGemmV2FFI / nvte_grouped_gemm) into the JAX TE layer, replacing the older multi-stream path that required a D2H memcpy and stream synchronisation for plain BF16 / NO_SCALING inputs. It also exposes a make_grouped_dense_cls factory for easy drop-in use in existing Flax MoE models, and fixes a dimension-swap bug in the C++ grouped GEMM test.

Key changes:

  • New C++ backend (GroupedGemmV2FFI) dispatches to nvte_grouped_gemm (cuBLAS 13.2+) and is registered with FFI_CudaGraph_Traits, making it safe for CUDA graph capture.
  • Automatic dispatch in Python: _can_use_v2_grouped_gemm selects the V2 path for bfloat16 + NO_SCALING + no-bias; otherwise falls back to the legacy multi-stream path.
  • _v2_grouped_gemm_available is cached at module import time via a lightweight probe, avoiding per-trace overhead.
  • make_grouped_dense_cls provides a make_dot_general_cls-style factory; however, the internal Flax module is registered under "ragged_dot" instead of "grouped_dense", creating a checkpoint-path inconsistency that should be fixed.
  • GroupedGemmV2FFI carries over several dead variables and an unreachable FP8 architecture check from the legacy path that were not cleaned up during the port, which will generate compiler warnings.
  • del kwargs in te_grouped_dot_general silently discards any keyword arguments instead of validating or forwarding them.
  • Bug fixes: jnp.empty → jnp.ones for the uninitialized scale buffer in grouped_quantize; M/N dimension swap fixed in the C++ test; d_rows/d_cols swap fixed in the setup kernel.

Confidence Score: 3/5

  • Safe for BF16 inference/training but has a checkpoint-breaking naming bug in make_grouped_dense_cls that should be fixed before wider adoption.
  • The core CUDA-graph-safe GEMM path and dispatch logic appear correct and the PR includes test coverage. However, the "ragged_dot" internal module name in make_grouped_dense_cls is a logic-level inconsistency that will silently corrupt checkpoint parameter paths for any model using that factory. Additionally, GroupedGemmV2FFI has accumulated a significant number of dead variables and an unreachable code block that will produce compiler warnings and hinder maintainability. These issues lower confidence from an otherwise well-structured feature addition.
  • transformer_engine/jax/flax/module.py (ragged_dot naming bug) and transformer_engine/jax/csrc/extensions/gemm.cpp (dead variables, unreachable FP8 check).

Important Files Changed

Filename Overview
transformer_engine/jax/csrc/extensions/gemm.cpp Adds new GroupedGemmV2FFI and supporting JAXX_GroupedTensorWrapper class for CUDA-graph-safe grouped GEMM. The implementation has several dead variables (lhs_ptr, rhs_ptr, out_ptr, num_math_sm, grad, accumulate, use_split_accumulator, bias_shape, out_dtype_bytes) and an unreachable FP8 arch check that were carried over from the legacy path. A redundant lhs_is_trans = true assignment in the wgrad branch is also present.
transformer_engine/jax/flax/module.py Adds make_grouped_dense_cls factory function and n_groups support in quantizer-set generation. Has two bugs: (1) the internal Flax module is registered as "ragged_dot" instead of "grouped_dense", causing checkpoint path inconsistency; (2) **kwargs are silently discarded rather than validated or forwarded.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Adds compute_grouped_tensor_offset device helper with cumulative-sum for per-tensor dims, exposes nvte_get_grouped_gemm_setup_workspace_size, and adds nvte_convert_int32_to_int64 utility kernel. Also fixes d_rows/d_cols swap. Code is in the correct preprocessor guards; the cuBLAS stubs are properly placed in the #else block.
transformer_engine/jax/cpp_extensions/gemm.py Integrates the V2 CUDA-graph-safe backend into the JAX primitive layer; adds _can_use_v2_grouped_gemm dispatch helper and _v2_grouped_gemm_available cached flag at import time. The dual-backend GroupedGemmPrimitive cleanly dispatches to either te_grouped_gemm_ffi or te_grouped_gemm_v2_ffi depending on dtype/scaling-mode.
transformer_engine/jax/cpp_extensions/quantization.py Fixes uninitialized scale buffer (jnp.empty → jnp.ones) and improves assertion message formatting. Both changes are correct and safe.
tests/cpp/operator/test_grouped_gemm.cu Fixes M/N dimension swap in test shape construction for transposed GEMM cases; shapes now match the expected A[N,K]/B[K,M] layout that corresponds to the fix in #2659.

Sequence Diagram

sequenceDiagram
    participant PY as grouped_gemm() [Python]
    participant PRIM as GroupedGemmPrimitive
    participant V2FFI as GroupedGemmV2FFI [C++]
    participant V1FFI as GroupedGemmFFI [C++]
    participant NVTE as nvte_grouped_gemm [cuBLAS 13.2+]
    participant MULTI as nvte_multi_tensor_gemm [legacy]

    PY->>PY: _can_use_v2_grouped_gemm(scaling_mode, dtype, has_bias)
    alt BF16 + NO_SCALING + no bias + cuBLAS≥13.2
        PY->>PRIM: bind(use_v2_ffi=True, alpha, beta)
        PRIM->>V2FFI: te_grouped_gemm_v2_ffi (FFI_CudaGraph_Traits)
        V2FFI->>V2FFI: nvte_convert_int32_to_int64(group_sizes)
        V2FFI->>V2FFI: build JAXX_GroupedTensorWrapper objects
        V2FFI->>NVTE: nvte_grouped_gemm(rhs, lhs, out, α, β, ws_setup, ws_cublas)
        NVTE-->>V2FFI: result (CUDA-graph safe)
        V2FFI-->>PY: output
    else FP8 / MXFP8 / has_bias / old cuBLAS
        PY->>PRIM: bind(use_v2_ffi=False, group_offset)
        PRIM->>V1FFI: te_grouped_gemm_ffi (multi-stream)
        V1FFI->>V1FFI: D2H memcpy + stream sync for group_sizes
        V1FFI->>MULTI: nvte_multi_tensor_gemm(per-stream)
        MULTI-->>V1FFI: result
        V1FFI-->>PY: output
    end

Last reviewed commit: 505feac

Contributor

@greptile-apps greptile-apps bot left a comment

8 files reviewed, 8 comments

jberchtold-nvidia and others added 7 commits February 24, 2026 13:10
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…void enabling JAX x64 globally

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…sion to a different suffix

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
jberchtold-nvidia and others added 3 commits February 24, 2026 13:37
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

right cuBLAS version

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review February 25, 2026 21:32
Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment

Comment on lines +448 to +464
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  if (meta.offsets) {
    return meta.offsets[idx];
  } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
    // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
    int64_t cumsum = 0;
    for (size_t i = 0; i < idx; i++) {
      int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
      int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
      cumsum += f * l;
    }
    return cumsum;
  } else {
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
}
Contributor

O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.

For large numbers of groups, consider either:

  1. Computing offsets once on CPU and passing them explicitly
  2. Using a parallel prefix sum (scan) to compute cumulative offsets
  3. Documenting this limitation if group counts are expected to be small
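
Options 1 and 2 above amount to an exclusive prefix sum over the per-group element counts. A host-side sketch of that computation (NumPy for clarity; on device this would be a scan kernel, e.g. via CUB):

```python
import numpy as np

def compute_offsets(first_dims, last_dims):
    """Exclusive prefix sum of per-tensor element counts:
    offsets[i] = sum_{j < i} first_dims[j] * last_dims[j]."""
    sizes = np.asarray(first_dims, dtype=np.int64) * np.asarray(last_dims, dtype=np.int64)
    offsets = np.zeros_like(sizes)
    np.cumsum(sizes[:-1], out=offsets[1:])  # shift by one for the exclusive scan
    return offsets

offsets = compute_offsets([2, 3, 4], [3, 4, 5])  # per-group sizes: 6, 12, 20
```

Precomputing this once is O(n) total work, versus O(n²) when each of the n threads re-derives its own offset serially.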

jberchtold-nvidia and others added 4 commits March 3, 2026 12:58
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

Comment on lines +1464 to +1466
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
Contributor

make_ragged_dot_cls returns an instance, not a class, contradicting its name and docstring.

The function is documented as "Creates a ragged dot (grouped GEMM) class" and follows the make_*_cls naming convention, but the trailing () instantiates the result. This is inconsistent with make_dot_general_cls (line 1443), which returns the class itself:

# make_dot_general_cls returns a class:
DotGeneral = make_dot_general_cls(recipe)
result = DotGeneral()(x, kernel)          # instantiate → call

# make_ragged_dot_cls (as written) returns an instance:
ragged_dot = make_ragged_dot_cls(None)    # already an instance!
result = ragged_dot(x, kernel, group_sizes)  # just call

This inconsistency will confuse users who expect API parity with make_dot_general_cls. Either:

  1. Remove the trailing () so the function returns the class (consistent with the name/docs), or
  2. Rename to make_ragged_dot to accurately reflect that it returns a ready-to-use module instance.
Suggested change
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)

Collaborator

Should we call it grouped_dense here to be consistent?

Contributor

Good catch! There's an inconsistency in terminology. The function calls grouped_dense internally (and the underlying implementation in dense.py is named grouped_dense), but the wrapper is called make_ragged_dot_cls.

Two options for consistency:

  1. Rename to make_grouped_dense_cls - matches the underlying implementation and would align with the grouped_dense function already exported from the JAX API
  2. Keep make_ragged_dot_cls - if "ragged dot" is intended as the user-facing terminology for this feature (though it's currently not used elsewhere in the codebase)

Which terminology do you want to standardize on - "ragged_dot" or "grouped_dense"?

Collaborator Author

Updated to rename to grouped_dense

Comment on lines +1997 to +2003
try:
    get_grouped_gemm_setup_workspace_size(1)
except RuntimeError as e:
    if "cublas" in str(e).lower():
        # If the workspace size function is not available, it means the cuda-graphable implementation is not available.
        return False
    raise e
Contributor

Fragile capability detection via exception message matching.

Using try/except with a substring match ("cublas" in str(e).lower()) for runtime capability detection is fragile:

  1. Substring fragility: If the error message ever changes (e.g., capitalization, wording), the if branch fails silently and the detection breaks completely, changing the fallback behavior.
  2. Called every JAX trace: This check runs on every invocation of _can_use_cuda_graphable_grouped_gemm during abstract evaluation, adding unnecessary overhead for what is a build-time constant.

A more robust approach would be to cache the result at module import time:

# Computed once at import; `None` signals "not yet tested".
_CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE: Optional[bool] = None

def _check_cuda_graphable_grouped_gemm_available() -> bool:
    global _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE
    if _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE is None:
        try:
            get_grouped_gemm_setup_workspace_size(1)
            _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE = True
        except RuntimeError as e:
            if "cublas" in str(e).lower():
                _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE = False
            else:
                raise
    return _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE

Collaborator

I'm not sure about this helper.

Why don't we simply check for NO_SCALING, then go with the v2?

Also, we should be able to check it inside primitive methods instead of passing it as an argument so that when we eventually remove the multi-stream approach, we don't need to change the primitive inputs.

Contributor

Good points on both fronts.

1. Simplifying the check:
You're right - if we're only supporting BF16 + NO_SCALING in v2 initially, the try/except adds complexity for runtime capability detection that may not be needed. The cublas version check could be done once at module import or handled by the FFI call itself (which already has #if CUBLAS_VERSION >= 130200 guards).

2. Checking inside primitives:
This is a better design. Instead of passing use_cuda_graphable_ffi through all the primitive signatures, you could:

  • Check scaling_mode, dtype, has_bias directly inside lowering() to decide which FFI to call
  • Or make it purely a dispatch decision in the C++ layer based on the inputs

This way when you eventually remove the multi-stream path, you just delete the old FFI handler and the dispatch logic - no need to change primitive inputs/outputs or all the call sites.

The runtime check (if still needed) could be a module-level _HAS_CUDA_GRAPHABLE_GEMM cached at import time rather than being checked on every trace.

Collaborator Author

@jberchtold-nvidia jberchtold-nvidia Mar 4, 2026

Why don't we simply check for NO_SCALING, then go with the v2?

In addition to NO_SCALING, we need to check the cuBLAS version. In fact, more than that, we need to check the cuBLAS version TE was built with, which may not be the same depending on how the user installed TE. Otherwise, we will dispatch to v2 but it will fail during lowering because the C++ functions raise an error in TE core based on the cuBLAS version it was compiled with. In this case, instead of erroring, I'd rather keep the old behavior and fall back to the slower implementation.

Also, we should be able to check it inside primitive methods instead of passing it as an argument so that when we eventually remove the multi-stream approach, we don't need to change the primitive inputs.

I was planning to do this originally but ran into two issues:

  1. I needed to check which backend we're using (original or V2) in multiple places in the primitive: both the abstract and the lowering require it. In the lowering, we don't have easy access to dtype, as it's not a jnp.dtype but an MLIR dtype that isn't as easy to convert.
  2. The MLIR ctx bakes in information about the inputs. You can't simply slice args in lowering and pick and choose which primitive to dispatch to; if you do, it'll give an MLIR error that the ctx and args don't match. As a result, I'd have to update the original FFI to accept alpha and beta, which would break FFI compatibility. That's why I've gone with the solution of passing additional_arg_0 and additional_arg_1, which needs to happen outside the primitive. This isn't ideal, but was required to get around this fixed-arg limitation. Additionally, the static arg indices are fixed at primitive registration, and currently the original FFI and the V2 FFI take a different number of args, so for the original FFI I have to use an unused additional_arg_1 in the outer impl.

As a result of both of these, it's simpler to just check once outside the primitive whether we should use the original or V2, then pass that as a bool flag. Given these restrictions of 1. and 2., I'm not sure this consolidated primitive is a cleaner solution than the split primitives (unless I can break FFI compatibility to unify the FFIs to make the original FFI take alpha and beta).

Comment on lines +556 to +557
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  return grouped_gemm_setup_workspace_size(num_tensors);
Collaborator

Register nvte_ API with NVTE_API_CALL.

Collaborator Author

Done

Comment on lines +784 to +786
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"),
Collaborator

@phu0ngng phu0ngng Mar 3, 2026

Remove these three attributes as they are not needed.

Collaborator Author

Removed has_bias and use_async_d2h_group_sizes. The is_grouped_dense_wgrad attribute is still required currently, but I agree that in the future we should refactor this logic.

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler, GroupedGemmCudaGraphableFFI,
Collaborator

Should we simply call it v2 for now and eventually rename it back to GroupedGEMM?

Besides, if it is possible, could we mark it as an experimental API and subject to changes?

Collaborator Author

Sounds good, renamed to V2

Comment on lines +1997 to +2003
try:
    get_grouped_gemm_setup_workspace_size(1)
except RuntimeError as e:
    if "cublas" in str(e).lower():
        # If the workspace size function is not available, it means the cuda-graphable implementation is not available.
        return False
    raise e
Collaborator

I'm not sure about this helper.

Why don't we simply check for NO_SCALING, then go with the v2?

Also, we should be able to check it inside primitive methods instead of passing it as an argument so that when we eventually remove the multi-stream approach, we don't need to change the primitive inputs.

Comment on lines +1464 to +1466
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
Collaborator

Should we call it grouped_dense here to be consistent?

* \param[in] num_tensors Number of tensors (GEMMs) in the group.
* \return Required size in bytes for workspace_setup.
*/
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors);
Collaborator

Nit: could we add get i.e., nvte_get_grouped_gemm_setup_workspace_size? :D

Collaborator Author

Updated

jberchtold-nvidia and others added 3 commits March 3, 2026 17:09
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
jberchtold-nvidia and others added 3 commits March 4, 2026 10:06
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Inconsistent CUDA version in upgrade hint

The error message says "Please upgrade to CUDA 13.1 or newer", but the corresponding stub for nvte_grouped_gemm_setup_workspace_size (line 673 in the same #else block) says "Please upgrade to CUDA 13.2 or newer". Both stubs guard the same CUBLAS_VERSION >= 130200 requirement, so the version suggested to users is contradictory depending on which function they call first. Consider aligning both messages to the same version string.

  NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ",
             CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer.");

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
O(n²) work in device kernel for case 2

The compute_grouped_tensor_offset helper is called from setup_grouped_gemm_kernel which launches one thread per group. For case 2 (per-tensor dims, no explicit offset array), thread idx serially loops from 0 to idx-1 to accumulate the cumulative sum. The total work across all n threads is 0 + 1 + 2 + … + (n-1) = O(n²).

For a typical MoE with 64 or 128 experts this is tolerable, but for larger group counts the kernel becomes the bottleneck. Note that case 2 is reachable from GroupedGemmV2FFI whenever first_dims is set without offsets, which is exactly the is_grouped_dense_wgrad path (set_group_sizes_only populates first_dims but not offsets).

Consider either:

  1. Computing offsets once with a parallel prefix-sum (scan) kernel before launching setup_grouped_gemm_kernel, or
  2. Documenting an explicit upper bound on supported group counts for this path.

transformer_engine/jax/csrc/extensions/gemm.cpp
Redundant assignment to lhs_is_trans

lhs_is_trans is a bool function parameter. The NVTE_CHECK three lines above already asserts that lhs_is_trans == true, so the explicit re-assignment on this line is dead code that may confuse future readers into thinking the caller can pass lhs_is_trans = false here.

    // lhs_is_trans is already guaranteed true by the NVTE_CHECK above.
    auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape);

@jberchtold-nvidia
Collaborator Author

/te-ci

Comment on lines +1464 to +1466
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
Contributor

Inconsistent internal module name "ragged_dot" in make_grouped_dense_cls

The wrap_function_in_te_state_module call passes "ragged_dot" as the Flax module name, but the public-facing function was renamed to make_grouped_dense_cls. This string becomes the Flax variable-collection name (used in init / apply param dicts and serialised checkpoints). Any model that saves weights with this wrapper will have "ragged_dot" in its checkpoint paths, which is both confusing and will silently break loading if the name is ever corrected in a follow-up PR.

Suggested change
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "grouped_dense"
)()

Comment on lines +1450 to +1451
def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs):
    del kwargs  # Unused
Contributor

del kwargs silently discards user-provided arguments

Any keyword argument passed to the returned module's __call__ (e.g. precision, preferred_element_type, or future contracting_dims) is silently swallowed rather than forwarded to grouped_dense or rejected with an informative error. Users who accidentally pass such arguments will get incorrect behaviour with no warning. Consider at minimum asserting that kwargs is empty:

def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs):
    assert not kwargs, f"Unexpected keyword arguments: {list(kwargs.keys())}"

Comment on lines +584 to +691
  // Inputs
  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
  bool has_bias = product(bias.dimensions()) > 0;
  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());

  NVTE_CHECK(group_sizes.dimensions().size() == 1);
  size_t num_gemms = group_sizes.dimensions()[0];

  // Convert int32 group_sizes to int64 into the dedicated output buffer.
  NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32.");
  auto *int64_sizes_ptr = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
  nvte_convert_int32_to_int64(reinterpret_cast<const int32_t *>(group_sizes.untyped_data()),
                              int64_sizes_ptr, num_gemms, stream);

  NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
             "Only non-quantized grouped GEMM is supported in current implementation.");

  // It is weird that TE/Common GEMM only use colwise for MXFP8
  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
  const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;

  // Outputs
  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
  auto setup_workspace_ptr = reinterpret_cast<uint8_t *>(setup_workspace->untyped_data());
  // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
  auto cublas_workspace_ptr = reinterpret_cast<uint8_t *>(cublas_workspace->untyped_data());
  cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr);
  auto workspace_total_size = product(cublas_workspace->dimensions());

  auto lhs_sinv_size = product(lhs_sinv.dimensions());
  auto rhs_sinv_size = product(rhs_sinv.dimensions());
  const size_t workspace_alignment_padding = 256;
  const size_t tensor_scaling_sinv_aligment = 16;
  const size_t mxfp8_scaling_sinv_alignment_padding = 256;
  auto workspace_size = workspace_total_size - workspace_alignment_padding;
  if (is_mxfp8_scaling) {
    // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4.
    workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding);
  } else if (is_tensor_scaling) {
    // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned
    // by 16 bytes to meet the requirement of CUDA 12.9.1 and later.
    workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size);
  }
  auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size;
  swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
  swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
  auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr;  // Already 256B aligned
  auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment;

  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);

  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");

  size_t expected_lhs_size = m * k;
  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
  size_t actual_lhs_size = product(lhs_data.dimensions());
  size_t actual_rhs_size = product(rhs_data.dimensions());
  size_t actual_out_size = product(output->dimensions());
  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
             expected_lhs_size, ", got ", actual_lhs_size);
  if (!is_grouped_dense_wgrad) {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
               " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
  } else {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size,
               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
               " = ", expected_out_size, ", got ", actual_out_size);
  }

  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  bool grad = false;
  bool accumulate = false;
  bool use_split_accumulator = false;
  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
  const int arch = cuda::sm_arch();

  if (arch < 100 && is_fp8_gemm) {
    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
"For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
"got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
}
Multiple dead variables and unreachable block in GroupedGemmV2FFI

Several variables computed in this function are never consumed by the final nvte_grouped_gemm call (the V2 API takes grouped-tensor wrappers, not raw pointers or per-stream bookkeeping), and will generate -Wunused-variable compiler warnings:

  • lhs_ptr, rhs_ptr (lines 585–586) — raw pointers, never passed anywhere
  • lhs_sinv_ptr, rhs_sinv_ptr (lines 587–588) — same
  • out_ptr (line 618) — same
  • num_math_sm, grad, accumulate, use_split_accumulator, bias_shape (lines 680–684) — multi-stream bookkeeping that is unused in the V2 path
  • out_dtype_bytes (line 652) — never referenced

Additionally, the FP8 architecture check (lines 687–691) is unreachable: the NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, ...) on line 606 guarantees is_fp8_gemm is always false in this function. The whole block can be removed or marked with a // TODO: wire up when FP8 is added comment alongside the other FP8 placeholders.

These should be cleaned up to avoid misleading future contributors and to keep compiler warnings clean.

Comment on lines +715 to +717
//// LHS
NVTEShape lhsShape{.data = {k, m}, .ndim = 2};
lhs_is_trans = true;
Redundant lhs_is_trans = true assignment

The NVTE_CHECK(lhs_is_trans && !rhs_is_trans, ...) guard on line 707 already guarantees that lhs_is_trans is true when execution reaches line 717. The reassignment is dead code and may confuse readers into thinking the shape construction depends on overriding a possibly-false value.

Suggested change:

```diff
 //// LHS
 NVTEShape lhsShape{.data = {k, m}, .ndim = 2};
-lhs_is_trans = true;
+// lhs_is_trans is already verified true by the NVTE_CHECK above.
```



Successfully merging this pull request may close these issues.

[Core] TE common nvte_grouped_gemm treats output layout as column-wise instead of rowwise
