[JAX] Integrate BF16 Grouped GEMM with on-device group sizes #2680
jberchtold-nvidia wants to merge 31 commits into NVIDIA:main
Conversation
Greptile Summary: This PR integrates a new CUDA-graph-safe BF16 grouped GEMM backend. Key changes:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as grouped_gemm() [Python]
    participant PRIM as GroupedGemmPrimitive
    participant V2FFI as GroupedGemmV2FFI [C++]
    participant V1FFI as GroupedGemmFFI [C++]
    participant NVTE as nvte_grouped_gemm [cuBLAS 13.2+]
    participant MULTI as nvte_multi_tensor_gemm [legacy]
    PY->>PY: _can_use_v2_grouped_gemm(scaling_mode, dtype, has_bias)
    alt BF16 + NO_SCALING + no bias + cuBLAS≥13.2
        PY->>PRIM: bind(use_v2_ffi=True, alpha, beta)
        PRIM->>V2FFI: te_grouped_gemm_v2_ffi (FFI_CudaGraph_Traits)
        V2FFI->>V2FFI: nvte_convert_int32_to_int64(group_sizes)
        V2FFI->>V2FFI: build JAXX_GroupedTensorWrapper objects
        V2FFI->>NVTE: nvte_grouped_gemm(rhs, lhs, out, α, β, ws_setup, ws_cublas)
        NVTE-->>V2FFI: result (CUDA-graph safe)
        V2FFI-->>PY: output
    else FP8 / MXFP8 / has_bias / old cuBLAS
        PY->>PRIM: bind(use_v2_ffi=False, group_offset)
        PRIM->>V1FFI: te_grouped_gemm_ffi (multi-stream)
        V1FFI->>V1FFI: D2H memcpy + stream sync for group_sizes
        V1FFI->>MULTI: nvte_multi_tensor_gemm(per-stream)
        MULTI-->>V1FFI: result
        V1FFI-->>PY: output
    end
```
Last reviewed commit: 505feac
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci L0 jax
/te-ci L0 jax
/te-ci
```cpp
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  if (meta.offsets) {
    return meta.offsets[idx];
  } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
    // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
    int64_t cumsum = 0;
    for (size_t i = 0; i < idx; i++) {
      int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
      int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
      cumsum += f * l;
    }
    return cumsum;
  } else {
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
}
```
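For reference, the three addressing cases in this kernel can be modeled in plain Python (an illustrative sketch only; the parameter names mirror the `TensorShapeInfo` fields but this is not the device code):

```python
def grouped_tensor_offset(idx, offsets=None, first_dims=None, last_dims=None,
                          uniform_first=0, uniform_last=0):
    """Python model of compute_grouped_tensor_offset's three cases."""
    # Case 1: explicit offsets are provided -- direct lookup.
    if offsets is not None:
        return offsets[idx]
    # Case 2: per-tensor dims without explicit offsets -- running sum of
    # first*last products for all preceding tensors (the O(idx) loop).
    if first_dims is not None or last_dims is not None:
        total = 0
        for i in range(idx):
            f = first_dims[i] if first_dims is not None else uniform_first
            l = last_dims[i] if last_dims is not None else uniform_last
            total += f * l
        return total
    # Case 3: uniform shapes -- a single multiply.
    return idx * uniform_first * uniform_last
```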
O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.
For large numbers of groups, consider one of the following:
- Computing offsets once on CPU and passing them explicitly
- Using a parallel prefix sum (scan) to compute cumulative offsets
- Documenting this limitation if group counts are expected to be small
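As a sketch of the prefix-sum alternative in plain Python (with `itertools.accumulate` standing in for a device-side parallel scan; this illustrates the idea, not the CUDA implementation):

```python
from itertools import accumulate

def exclusive_offsets(first_dims, last_dims):
    """O(n) exclusive prefix sum of per-group element counts.

    Replaces the per-thread O(idx) loop: compute all offsets once,
    then each thread reads offsets[idx] in O(1).
    """
    sizes = [f * l for f, l in zip(first_dims, last_dims)]
    # Exclusive scan: offsets[i] = sum of sizes[:i].
    return [0] + list(accumulate(sizes))[:-1]
```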
/te-ci
```python
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
```
make_ragged_dot_cls returns an instance, not a class, contradicting its name and docstring.
The function is documented as "Creates a ragged dot (grouped GEMM) class" and follows the make_*_cls naming convention, but the trailing () instantiates the result. This is inconsistent with make_dot_general_cls (line 1443), which returns the class itself:
```python
# make_dot_general_cls returns a class:
DotGeneral = make_dot_general_cls(recipe)
result = DotGeneral()(x, kernel)  # instantiate -> call

# make_ragged_dot_cls (as written) returns an instance:
ragged_dot = make_ragged_dot_cls(None)  # already an instance!
result = ragged_dot(x, kernel, group_sizes)  # just call
```

This inconsistency will confuse users who expect API parity with make_dot_general_cls. Either:
- Remove the trailing `()` so the function returns the class (consistent with the name/docs), or
- Rename to `make_ragged_dot` to accurately reflect that it returns a ready-to-use module instance.
Suggested change:

```python
# Before:
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()

# After:
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)
```
Should we call it grouped_dense here to be consistent?
Good catch! There's an inconsistency in terminology. The function calls grouped_dense internally (and the underlying implementation in dense.py is named grouped_dense), but the wrapper is called make_ragged_dot_cls.
Two options for consistency:
- Rename to `make_grouped_dense_cls` - matches the underlying implementation and would align with the `grouped_dense` function already exported from the JAX API.
- Keep `make_ragged_dot_cls` - if "ragged dot" is intended as the user-facing terminology for this feature (though it's currently not used elsewhere in the codebase).

Which terminology do you want to standardize on - "ragged_dot" or "grouped_dense"?
Updated to rename to grouped_dense
```python
try:
    get_grouped_gemm_setup_workspace_size(1)
except RuntimeError as e:
    if "cublas" in str(e).lower():
        # If the workspace size function is not available, it means the
        # cuda-graphable implementation is not available.
        return False
    raise e
```
Fragile capability detection via exception message matching.
Using try/except with a substring match (`"cublas" in str(e).lower()`) for runtime capability detection is fragile:
- Substring fragility: if the error message ever changes (e.g., capitalization, wording), the `if` branch fails silently and the detection breaks completely, changing the fallback behavior.
- Called on every JAX trace: this check runs on every invocation of `_can_use_cuda_graphable_grouped_gemm` during abstract evaluation, adding unnecessary overhead for what is a build-time constant.
A more robust approach would be to cache the result at module import time:

```python
# Computed once at import; `None` signals "not yet tested".
_CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE: Optional[bool] = None

def _check_cuda_graphable_grouped_gemm_available() -> bool:
    global _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE
    if _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE is None:
        try:
            get_grouped_gemm_setup_workspace_size(1)
            _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE = True
        except RuntimeError as e:
            if "cublas" in str(e).lower():
                _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE = False
            else:
                raise
    return _CUDA_GRAPHABLE_GROUPED_GEMM_AVAILABLE
```
I'm not sure about this helper.
Why don't we simply check for NO_SCALING, then go with the v2?
Also, we should be able to check it inside primitive methods instead of passing it as an argument so that when we eventually remove the multi-stream approach, we don't need to change the primitive inputs.
Good points on both fronts.
1. Simplifying the check:
You're right - if we're only supporting BF16 + NO_SCALING in v2 initially, the try/except adds complexity for runtime capability detection that may not be needed. The cublas version check could be done once at module import or handled by the FFI call itself (which already has #if CUBLAS_VERSION >= 130200 guards).
2. Checking inside primitives:
This is a better design. Instead of passing use_cuda_graphable_ffi through all the primitive signatures, you could:
- Check `scaling_mode`, `dtype`, and `has_bias` directly inside `lowering()` to decide which FFI to call, or
- Make it purely a dispatch decision in the C++ layer based on the inputs.

This way when you eventually remove the multi-stream path, you just delete the old FFI handler and the dispatch logic - no need to change primitive inputs/outputs or all the call sites.
The runtime check (if still needed) could be a module-level _HAS_CUDA_GRAPHABLE_GEMM cached at import time rather than being checked on every trace.
> Why don't we simply check for NO_SCALING, then go with the v2?
In addition to NO_SCALING, we need to check the cuBLAS version. In fact, more than that, we need to check the cuBLAS version TE was built with, which may not be the same depending on how the user installed TE. Otherwise, we will dispatch to v2 but it will fail during lowering because the C++ functions raise an error in TE core based on the cuBLAS version it was compiled with. In this case, instead of erroring, I'd rather keep the old behavior and fall back to the slower implementation.
> Also, we should be able to check it inside primitive methods instead of passing it as an argument so that when we eventually remove the multi-stream approach, we don't need to change the primitive inputs.
I was planning to do this originally but ran into two issues:
1. I needed to check which backend we're using (original or V2) in multiple places in the primitive. The abstract requires it and the lowering requires it. In the lowering, we don't have easy access to dtype: it's not a jnp.dtype, it's some MLIR dtype that isn't as easy to convert.
2. The MLIR `ctx` bakes in information about the inputs. You can't simply slice `args` in lowering and pick and choose which primitive to dispatch to. If you do, it'll give an MLIR error that the `ctx` and `args` don't match. As a result, I'd have to update the original FFI to accept `alpha` and `beta`, which would break FFI compatibility. This is why I've gone with the solution of passing `additional_arg_0` and `additional_arg_1`, which needs to happen outside the primitive. This isn't ideal, but was required to get around this fixed-arg limitation. Additionally, the static arg indices are fixed at primitive registration, and currently the original FFI and the V2 FFI take a different number of args, so for the original FFI I have to use an unused `additional_arg_1` in the outer impl.

As a result of both of these, it's simpler to just check once outside the primitive whether we should use the original or V2, then pass that as a bool flag. Given restrictions 1 and 2, I'm not sure this consolidated primitive is a cleaner solution than the split primitives (unless I can break FFI compatibility to unify the FFIs so the original FFI takes alpha and beta).
```cpp
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  return grouped_gemm_setup_workspace_size(num_tensors);
}
```
Register nvte_ API with NVTE_API_CALL.
```cpp
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"),
```
Remove these three attributes as they are not needed.
Removed has_bias and use_async_d2h_group_sizes. is_grouped_dense_wgrad is still required currently, but I agree we should refactor this logic in the future.
```cpp
return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmCudaGraphableHandler, GroupedGemmCudaGraphableFFI,
```
Should we simply call it v2 for now and eventually rename it back to GroupedGEMM?
Besides, if it is possible, could we mark it as an experimental API, subject to change?
Sounds good, renamed to V2
```cpp
 * \param[in] num_tensors Number of tensors (GEMMs) in the group.
 * \return Required size in bytes for workspace_setup.
 */
size_t nvte_grouped_gemm_setup_workspace_size(size_t num_tensors);
```
Nit: could we add get i.e., nvte_get_grouped_gemm_setup_workspace_size? :D
Additional Comments (2)
The offset computation is quadratic in the number of groups. For a typical MoE with 64 or 128 experts this is tolerable, but for larger group counts the kernel becomes the bottleneck. Note that case 2 (per-tensor dims without explicit offsets) is reachable. Consider either:
/te-ci
```python
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()
```
Inconsistent internal module name "ragged_dot" in make_grouped_dense_cls
The wrap_function_in_te_state_module call passes "ragged_dot" as the Flax module name, but the public-facing function was renamed to make_grouped_dense_cls. This string becomes the Flax variable-collection name (used in init / apply param dicts and serialised checkpoints). Any model that saves weights with this wrapper will have "ragged_dot" in its checkpoint paths, which is both confusing and will silently break loading if the name is ever corrected in a follow-up PR.
Suggested change:

```python
# Before:
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "ragged_dot"
)()

# After:
return wrap_function_in_te_state_module(
    te_grouped_dot_general, quantization_recipe, "grouped_dense"
)()
```
```python
def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs):
    del kwargs  # Unused
```
del kwargs silently discards user-provided arguments
Any keyword argument passed to the returned module's `__call__` (e.g. precision, preferred_element_type, or future contracting_dims) is silently swallowed rather than forwarded to grouped_dense or rejected with an informative error. Users who accidentally pass such arguments will get incorrect behaviour with no warning. Consider at minimum asserting that kwargs is empty:

```python
def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs):
    assert not kwargs, f"Unexpected keyword arguments: {list(kwargs.keys())}"
```
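A small self-contained sketch of why the silent `del` is risky (the function names here are illustrative, not the TE API):

```python
def silently_swallows(x, **kwargs):
    del kwargs  # Unused -- a misspelled keyword vanishes here with no warning.
    return x * 2

def rejects_extras(x, **kwargs):
    # Fails loudly instead of silently ignoring unknown keywords.
    assert not kwargs, f"Unexpected keyword arguments: {list(kwargs)}"
    return x * 2
```

Passing a misspelled keyword such as `precission=...` is accepted without complaint by the first variant but raises an AssertionError naming the offender in the second.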
```cpp
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
bool has_bias = product(bias.dimensions()) > 0;
auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());

NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];

// Convert int32 group_sizes to int64 into the dedicated output buffer.
NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32.");
auto *int64_sizes_ptr = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
nvte_convert_int32_to_int64(reinterpret_cast<const int32_t *>(group_sizes.untyped_data()),
                            int64_sizes_ptr, num_gemms, stream);

NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
           "Only non-quantized grouped GEMM is supported in current implementation.");

// It is weird that TE/Common GEMM only use colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                               scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;

// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto setup_workspace_ptr = reinterpret_cast<uint8_t *>(setup_workspace->untyped_data());
// Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
auto cublas_workspace_ptr = reinterpret_cast<uint8_t *>(cublas_workspace->untyped_data());
cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr);
auto workspace_total_size = product(cublas_workspace->dimensions());

auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
const size_t workspace_alignment_padding = 256;
const size_t tensor_scaling_sinv_aligment = 16;
const size_t mxfp8_scaling_sinv_alignment_padding = 256;
auto workspace_size = workspace_total_size - workspace_alignment_padding;
if (is_mxfp8_scaling) {
  // For MXFP8 swizzled scale_inv buffers, only the first pointer needs the 256B alignment
  // padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are
  // padded by 128x4.
  workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding);
} else if (is_tensor_scaling) {
  // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned
  // by 16 bytes to meet the requirement of CUDA 12.9.1 and later.
  workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size);
}
auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size;
swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr;  // Already 256B aligned
auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment;

size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);

NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
           "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");

size_t expected_lhs_size = m * k;
size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
size_t actual_lhs_size = product(lhs_data.dimensions());
size_t actual_rhs_size = product(rhs_data.dimensions());
size_t actual_out_size = product(output->dimensions());
NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
           expected_lhs_size, ", got ", actual_lhs_size);
if (!is_grouped_dense_wgrad) {
  NVTE_CHECK(expected_rhs_size == actual_rhs_size,
             "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
             " = ", expected_rhs_size, ", got ", actual_rhs_size);
  NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
             " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
} else {
  NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
             " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
  NVTE_CHECK(expected_out_size == actual_out_size,
             "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
             " = ", expected_out_size, ", got ", actual_out_size);
}

auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
bool grad = false;
bool accumulate = false;
bool use_split_accumulator = false;
auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
const int arch = cuda::sm_arch();

if (arch < 100 && is_fp8_gemm) {
  NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
             "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
             "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
}
```
Multiple dead variables and unreachable block in GroupedGemmV2FFI
Several variables computed in this function are never consumed by the final nvte_grouped_gemm call (the V2 API takes grouped-tensor wrappers, not raw pointers or per-stream bookkeeping), and will generate -Wunused-variable compiler warnings:
- `lhs_ptr`, `rhs_ptr` (lines 585-586) - raw pointers, never passed anywhere
- `lhs_sinv_ptr`, `rhs_sinv_ptr` (lines 587-588) - same
- `out_ptr` (line 618) - same
- `num_math_sm`, `grad`, `accumulate`, `use_split_accumulator`, `bias_shape` (lines 680-684) - multi-stream bookkeeping that is unused in the V2 path
- `out_dtype_bytes` (line 652) - never referenced
Additionally, the FP8 architecture check (lines 687–691) is unreachable: the NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, ...) on line 606 guarantees is_fp8_gemm is always false in this function. The whole block can be removed or marked with a // TODO: wire up when FP8 is added comment alongside the other FP8 placeholders.
These should be cleaned up to avoid misleading future contributors and to keep compiler warnings clean.
```cpp
//// LHS
NVTEShape lhsShape{.data = {k, m}, .ndim = 2};
lhs_is_trans = true;
```
Redundant lhs_is_trans = true assignment
The NVTE_CHECK(lhs_is_trans && !rhs_is_trans, ...) guard on line 707 already guarantees that lhs_is_trans is true when execution reaches line 717. The reassignment is dead code and may confuse readers into thinking the shape construction depends on overriding a possibly-false value.
Suggested change:

```cpp
// Before:
//// LHS
NVTEShape lhsShape{.data = {k, m}, .ndim = 2};
lhs_is_trans = true;

// After:
//// LHS
NVTEShape lhsShape{.data = {k, m}, .ndim = 2};
// lhs_is_trans is already verified true by the NVTE_CHECK above.
```
Description
Integrate new BF16 grouped GEMM from TE common/cuBLASLt that supports on-device group sizes without a D2H memcpy and stream sync. This grouped GEMM is faster and CUDA-graph safe.
Also fixes #2659
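For readers unfamiliar with the operation: a grouped (ragged) GEMM multiplies contiguous row-groups of one matrix against per-group weight matrices, with the group boundaries given by group_sizes. A minimal NumPy reference of the semantics (an illustration only; the PR's implementation runs in cuBLAS with the group sizes resident on device):

```python
import numpy as np

def grouped_gemm_reference(lhs, rhs, group_sizes):
    """Reference grouped GEMM.

    lhs is (m, k) with rows partitioned into groups, rhs is
    (num_groups, k, n), and sum(group_sizes) == m.
    """
    assert sum(group_sizes) == lhs.shape[0]
    out = np.empty((lhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
    start = 0
    for g, size in enumerate(group_sizes):
        # Each contiguous row-group is multiplied by its own weight matrix.
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out
```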
Type of change
Changes
- `make_ragged_dot_cls` for easy integration into existing models. This will be most useful when quantization is supported and storing recipe state is required.

Checklist: