31 commits
7c46453
Grouped GEMM
jberchtold-nvidia Feb 13, 2026
5a96845
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
49b45fa
disable cuda-graph for GMM
jberchtold-nvidia Feb 24, 2026
593a790
proper workspace size
jberchtold-nvidia Feb 24, 2026
ae34461
remove duplicate workspace size logic in Python gemm.py
jberchtold-nvidia Feb 24, 2026
7e99c64
use group_sizes as int32 and handle int64 and offsets inside FFI to a…
jberchtold-nvidia Feb 24, 2026
a661e9e
restore previous non-cuda-graphable grouped GEMM FFI and move new ver…
jberchtold-nvidia Feb 24, 2026
6fd7f16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
0d5837d
cleanup and lint fixes
jberchtold-nvidia Feb 24, 2026
d3ee0fc
re-add cublas alignment checks
jberchtold-nvidia Feb 24, 2026
6440648
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
661a829
fix symbol export when building with older cublas
jberchtold-nvidia Feb 25, 2026
7d15c4c
Merge branch 'main' into gmm
jberchtold-nvidia Feb 25, 2026
bd5e6fb
Fix backend selection depending on whether TE was compiled with the
jberchtold-nvidia Feb 25, 2026
60d5c42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2026
70d8f78
Merge branch 'main' into gmm
jberchtold-nvidia Feb 25, 2026
5b18801
Consolidate grouped GEMM primitives
jberchtold-nvidia Mar 3, 2026
59521e8
fixes
jberchtold-nvidia Mar 3, 2026
0f5d95a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
8361484
Merge branch 'main' into gmm
jberchtold-nvidia Mar 3, 2026
3ce29cb
Fixes
jberchtold-nvidia Mar 4, 2026
5927203
Update C++ grouped GEMM tests to address row-major/col-major bugfix
jberchtold-nvidia Mar 4, 2026
234eda1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
62ade26
Update transformer_engine/jax/cpp_extensions/gemm.py
jberchtold-nvidia Mar 4, 2026
855c3be
Update transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
jberchtold-nvidia Mar 4, 2026
c20eea6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
54e2d18
Address review comments
jberchtold-nvidia Mar 4, 2026
b09c378
Fix GMM runtime errors with nvte_set_grouped_tensor_param in gemm.cpp
jberchtold-nvidia Mar 4, 2026
0a86b73
Rename CudaGraphable to V2 and remove unnecessary FFI attributes
jberchtold-nvidia Mar 4, 2026
b9a8155
Lint
jberchtold-nvidia Mar 4, 2026
505feac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
8 changes: 4 additions & 4 deletions tests/cpp/operator/test_grouped_gemm.cu
@@ -123,10 +123,10 @@ void run_grouped_gemm_case(const TestParams& params) {

for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
-    const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
-                                                       : std::vector<size_t>{K, M};
-    const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
-                                                       : std::vector<size_t>{N, K};
+    const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{N, K}
+                                                       : std::vector<size_t>{K, N};
+    const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, M}
+                                                       : std::vector<size_t>{M, K};
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
73 changes: 61 additions & 12 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
return heuristicResult.algo;
}

// Device helper: compute the element offset for tensor `idx` given shape metadata.
// Three cases:
// 1. Explicit per-tensor offset array provided → use it directly.
// 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products.
// 3. Fully uniform shapes → idx * uniform_first * uniform_last.
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
size_t idx) {
if (meta.offsets) {
return meta.offsets[idx];
} else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
// offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
int64_t cumsum = 0;
for (size_t i = 0; i < idx; i++) {
int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
cumsum += f * l;
}
return cumsum;
} else {
return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
}
}
Contributor commented on lines +448 to +464:
O(n²) complexity in parallel kernel. Each of the n threads calls this function with a different idx, and for case 2 (per-tensor dims without explicit offsets), thread idx performs a sequential loop from 0 to idx-1. This creates O(1 + 2 + ... + n) = O(n²) total work across all threads.

For large numbers of groups, consider one of the following (a sketch of option 1 follows the list):

  1. Computing offsets once on CPU and passing them explicitly
  2. Using a parallel prefix sum (scan) to compute cumulative offsets
  3. Documenting this limitation if group counts are expected to be small
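
A minimal host-side sketch of option 1, assuming the per-tensor dims are visible on the host; names here are illustrative, not the TE API:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: exclusive prefix sum of first*last products, computed
// once on the host. Copying the result to the device and setting it as the
// explicit offsets array makes the device-side lookup O(1) per thread instead
// of the O(idx) loop in case 2.
std::vector<int64_t> compute_offsets_on_host(const std::vector<int64_t>& first_dims,
                                             const std::vector<int64_t>& last_dims) {
  std::vector<int64_t> offsets(first_dims.size());
  int64_t cumsum = 0;
  for (size_t i = 0; i < first_dims.size(); ++i) {
    offsets[i] = cumsum;  // offset[i] = sum_{j < i} first[j] * last[j]
    cumsum += first_dims[i] * last_dims[i];
  }
  return offsets;
}
```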


// Single kernel that sets up all GEMM parameters.
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
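
From the fields used in this file, the shape metadata presumably looks roughly like the following; this is a reconstruction from usage, not the actual TE definition:

```cpp
// Reconstructed from usage (meta.offsets, meta.first_dims, meta.last_dims,
// meta.uniform_first, meta.uniform_last); the real TensorShapeInfo may differ.
struct TensorShapeInfo {
  const int64_t* offsets;     // optional explicit per-tensor element offsets
  const int64_t* first_dims;  // optional per-tensor first dimensions
  const int64_t* last_dims;   // optional per-tensor last dimensions
  int64_t uniform_first;      // used when first_dims == nullptr
  int64_t uniform_last;       // used when last_dims == nullptr
};
```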
@@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel(
int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;

-  // Compute offsets (from array or compute from uniform dims)
-  int64_t a_offset =
-      A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
-  int64_t b_offset =
-      B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
-  int64_t c_offset =
-      C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
-  int64_t d_offset =
-      D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
+  // Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform)
+  int64_t a_offset = compute_grouped_tensor_offset(A_meta, idx);
+  int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx);
+  int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx);
+  int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx);

// Compute data pointers
A_ptrs[idx] = a_base + a_offset * a_elem_size;
@@ -487,9 +506,8 @@
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
-  // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
-  d_rows[idx] = static_cast<int>(d_first);
-  d_cols[idx] = static_cast<int>(d_last);
+  d_rows[idx] = static_cast<int>(d_last);
+  d_cols[idx] = static_cast<int>(d_first);

// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
@@ -535,6 +553,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {

} // namespace

size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
NVTE_API_CALL(nvte_get_grouped_gemm_setup_workspace_size);
return grouped_gemm_setup_workspace_size(num_tensors);
}

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
@@ -642,4 +665,30 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
}

size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
NVTE_ERROR(
"nvte_get_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time cuBLAS "
"version is ",
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
return 0;
}

#endif // CUBLAS_VERSION >= 130200

namespace {

__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
}

} // namespace

void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_int32_to_int64);
if (n == 0) return;
const int threads = 256;
const int blocks = static_cast<int>((n + threads - 1) / threads);
convert_int32_to_int64_kernel<<<blocks, threads, 0, stream>>>(src, dst, n);
NVTE_CHECK_CUDA(cudaGetLastError());
}
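
A minimal usage sketch for this helper, with illustrative buffer names and error handling elided; it widens int32 group sizes already resident on the device into the int64 array the grouped GEMM path consumes:

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

// Sketch: allocate an int64 device array and widen the existing int32 sizes
// into it on the given stream. d_sizes_i32 is assumed to already hold
// num_groups entries on the device.
void prepare_group_sizes(const int32_t* d_sizes_i32, size_t num_groups,
                         int64_t** d_sizes_i64, cudaStream_t stream) {
  cudaMalloc(d_sizes_i64, num_groups * sizeof(int64_t));
  nvte_convert_int32_to_int64(d_sizes_i32, *d_sizes_i64, num_groups, stream);
}
```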
25 changes: 25 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
@@ -329,6 +329,31 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM.
*
* The setup workspace stores pointer arrays and per-matrix dimension arrays used
* by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs)
* in the group and is independent of matrix dimensions.
*
* Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm.
*
* \param[in] num_tensors Number of tensors (GEMMs) in the group.
* \return Required size in bytes for workspace_setup.
*/
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors);

/*! \brief Convert a device array of int32 values to int64 values.
*
* Useful for preparing group_sizes for nvte_grouped_gemm when the caller
* holds int32 sizes and needs int64 values on the device.
*
* \param[in] src Device pointer to source int32 array.
* \param[out] dst Device pointer to destination int64 array.
* \param[in] n Number of elements.
* \param[in] stream CUDA stream.
*/
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream);

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
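
For reference, a hedged sketch of how a caller might size the setup workspace declared above. Wrapping the raw buffer into the workspace_setup NVTETensor is left out, since that step depends on TE's tensor-construction helpers:

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

// Sketch: allocate the setup workspace for a grouped GEMM over num_gemms
// matrices. Per the API doc, the size depends only on the number of GEMMs
// in the group, not on the matrix dimensions.
void* alloc_setup_workspace(size_t num_gemms) {
  const size_t bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms);
  void* ws = nullptr;
  cudaMalloc(&ws, bytes);
  return ws;  // wrap as the workspace_setup tensor before nvte_grouped_gemm
}
```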