
Split Kernel Optimization for group_index_select_or_add_2d_kernel#142

Open
aryaman-gupta wants to merge 27 commits into main_12162025_upstream from aryaman/group-index-optimizations

Conversation

@aryaman-gupta

Overview

Optimizes group_index_select operations on ROCm by splitting tables into small and large embedding groups and launching specialized kernels for each.

Implementation

Introduces compile-time template parameter USE_SMALL_EMB_DIM that splits kernel execution:

  • Small embeddings (num_cols < cols_per_warp): Pack multiple rows per warp for better occupancy
  • Large embeddings (num_cols >= cols_per_warp): Standard one-or-more warps per row
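
The split criterion above can be sketched as a small host-side helper. This is a hypothetical illustration (the function name and signature are not from the PR; the real code derives `cols_per_warp` from the warp size and `UNROLL_FACTOR`):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the warp-accounting logic described above.
// Small tables pack several rows into one warp; large tables use one
// or more warps per row.
int64_t warps_needed(int64_t num_rows, int64_t num_cols, int64_t cols_per_warp) {
  if (num_cols < cols_per_warp) {
    // Small embedding: each warp covers rows_per_warp rows.
    const int64_t rows_per_warp = cols_per_warp / num_cols;
    return (num_rows + rows_per_warp - 1) / rows_per_warp;
  }
  // Large embedding: each row needs ceil(num_cols / cols_per_warp) warps.
  const int64_t warps_per_row = (num_cols + cols_per_warp - 1) / cols_per_warp;
  return num_rows * warps_per_row;
}
```

With, say, 8-column tables and 32 lanes of work per warp, four rows share a warp instead of each row leaving most lanes idle.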

Tables are separated into two groups during host-side processing, then launched with two separate kernel invocations.

Why It Works

  1. Independent use_var_cols tracking: Each group (small/large) maintains its own use_var_cols flag. If all small embeddings have identical dimensions, the small kernel skips variable column logic entirely, even when large embeddings vary.

  2. Compact kernel generation: Compile-time specialization eliminates runtime branching between small/large paths, producing tighter kernels.
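
The independent `use_var_cols` tracking in point 1 can be sketched as follows (struct and function names are illustrative, not the PR's actual bookkeeping):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch: track a use_var_cols flag independently for the
// small and the large group while splitting tables.
struct SplitFlags {
  bool use_var_cols_small = false;
  bool use_var_cols_large = false;
};

SplitFlags split_flags(const std::vector<int64_t>& num_cols_per_table,
                       int64_t cols_per_warp) {
  SplitFlags f;
  int64_t prev_small = -1, prev_large = -1;
  for (int64_t c : num_cols_per_table) {
    if (c < cols_per_warp) {
      if (prev_small != -1 && c != prev_small) f.use_var_cols_small = true;
      prev_small = c;
    } else {
      if (prev_large != -1 && c != prev_large) f.use_var_cols_large = true;
      prev_large = c;
    }
  }
  return f;
}
```

So a group of tables with dims {8, 8, 64, 128} keeps the fast fixed-column path for the small kernel even though the large group varies.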

Files Modified

Kernel Implementation (sparse_group_index.cu)

  • Added USE_SMALL_EMB_DIM template parameter to the group_index_select_or_add_2d_kernel for compile-time kernel specialization

Host GPU Code (sparse_ops_gpu.cpp)

  • Forward pass: Split tables into small/large groups, allocate dual args_tensor_small/large and saved_data_small/large, launch separate kernels
  • Backward pass: Unpack dual saved data, split gradients into small/large groups, launch separate kernels
  • Returns group_size + 4 tensors on ROCm (2 args_tensors + 2 saved_data) vs group_size + 2 on CUDA
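
The backend-dependent return count reduces to a one-liner; this helper is a hypothetical sketch, not code from the PR:

```cpp
#include <cassert>

// Sketch of the saved-tensor count described above: on ROCm the forward
// pass saves two args_tensors plus two saved_data tensors; on CUDA it
// saves one of each.
int num_saved_tensors(int group_size, bool is_rocm) {
  return group_size + (is_rocm ? 4 : 2);
}
```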

CPU Wrapper (sparse_ops_cpu.cpp)

  • Updated tensor count validation to handle variable returns. The CPU implementation of group_index_select remains unmodified and returns 2 elements, while the ROCm GPU implementation now returns 4.

Python Interface (sparse_ops.py)

  • Added runtime detection (torch.version.hip is not None) to allocate appropriate number of tensors

Header (sparse_ops.h)

  • Added use_small_emb_dim parameter to group_index_select_or_add_cuda function signature to control kernel specialization (conditionally compiled for ROCm only)

aryaman-gupta and others added 27 commits December 12, 2025 15:09
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2206

Pull Request resolved: pytorch#5207

- Revert the changes made in D87922263

Reviewed By: cthi, atalman, huydhn

Differential Revision: D88774663

fbshipit-source-id: ecc0486eb82564ebc31eac503c58a35600816548
@aryaman-gupta aryaman-gupta added the enhancement New feature or request label Feb 2, 2026

@avbokovoy left a comment


The overall logic looks correct to me. I've proposed several changes that should significantly reduce code duplication and diff size, and will probably make a potential CUDA integration easier (at least on the GPU and invoke-function sides).


Discussed this offline. I believe code duplication could be reduced by having a unified kernel template for both the Nvidia and ROCm paths, and using std::visit with a platform-specific std::variant as in https://github.com/ROCm/FBGEMM/pull/139/changes#diff-6f509196a8893b5345f5e615251ce85ea5f575b81c1e9136fff764a899d92562R329-R407. This way we won't need to duplicate the invoke macro, and compilation time/binary size will stay in line with expectations for each platform. It will also make adding new template parameters easier.
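
The std::visit pattern being suggested can be sketched like this. All types here are illustrative stubs, not the actual FBGEMM kernel configs; the idea is that platform-specific launch arguments live in a std::variant and a single visitor dispatches at the invoke site:

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

// Stub configs standing in for platform-specific launch parameters.
struct CudaConfig {};
struct RocmConfig { bool use_small_emb_dim; };

using LaunchConfig = std::variant<CudaConfig, RocmConfig>;

// One visitor selects the specialization; no duplicated invoke macro.
std::string launch(const LaunchConfig& cfg) {
  return std::visit(
      [](const auto& c) -> std::string {
        using T = std::decay_t<decltype(c)>;
        if constexpr (std::is_same_v<T, RocmConfig>) {
          return c.use_small_emb_dim ? "rocm/small" : "rocm/large";
        } else {
          return "cuda";
        }
      },
      cfg);
}
```

Only the variant's alternatives differ per platform, so adding a new template parameter means extending one config struct rather than every macro expansion.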

// All columns are the same
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM


Do we really need to wrap it in USE_ROCM if we already have an if constexpr (USE_SMALL_EMB_DIM)?

#endif
}

#ifdef USE_ROCM


Comment on lines +122 to +131
int64_t start_row = member_warp_id * rows_per_warp;

// Since we are processing multiple rows within the warp, we need to
// map each lane to a specific row, in addition to the column
const auto local_row = (threadIdx.x * UNROLL_FACTOR) /
num_cols; // the row ID within the set of rows handled by this warp
const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
const int64_t current_row = start_row +
local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if
// [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are
// within num_work_rows
int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
// and we also need to confirm that we are within num_work_rows


Why did we remove the const qualifier?

const bool use_var_cols) {
const bool use_var_cols
#ifdef USE_ROCM
,const bool use_small_emb_dim


We could force use_small_emb_dim=false on the caller side in the CUDA case and keep the API consistent.

TORCH_CHECK(saved_data_t_large.is_contiguous(), "Tensor saved_data_t_large must be contiguous.");
memcpy(saved_data_t_large.mutable_data_ptr<int64_t>(), saved_data_large, sizeof(saved_data_large));

if(small.count > 0) {


We already have an early-exit condition in group_index_select_or_add_cuda for group_size == 0, so this guard and the one below are redundant.

Comment on lines +421 to +453
#ifdef USE_ROCM
if (num_cols_ < cols_per_warp) {
// Optimization for Small Embedding: Pack multiple rows per warp

if(!first_small_table && num_cols_ != prev_num_cols_small) {
use_var_cols_small = true;
}
first_small_table = false;
prev_num_cols_small = num_cols_;
small.input_ptrs[small.count] = reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
small.output_ptrs[small.count] = reinterpret_cast<int64_t>(output.mutable_data_ptr());
small.indices_ptrs[small.count] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
small.num_cols_group[small.count] = num_cols_;
small.warp_offsets_group[small.count] = small.total_warps;
small.total_warps += warps_needed;
small.count++;
} else {
// Standard Embedding: One or more warps per row

if(!first_large_table && num_cols_ != prev_num_cols_large) {
use_var_cols_large = true;
}
first_large_table = false;
prev_num_cols_large = num_cols_;

large.input_ptrs[large.count] = reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
large.output_ptrs[large.count] = reinterpret_cast<int64_t>(output.mutable_data_ptr());
large.indices_ptrs[large.count] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
large.num_cols_group[large.count] = num_cols_;
large.warp_offsets_group[large.count] = large.total_warps;
large.total_warps += warps_needed;
large.count++;
}


It seems like this block of code does the same thing in both the if and else branches. Maybe it's a good idea to choose the group (large/small) with a ternary condition and then apply the operation to the chosen group. Something like:

auto& dst_groups = (num_cols_ < cols_per_warp) ? small : large;
auto& first_table = (num_cols_ < cols_per_warp) ? first_small_table : first_large_table;
auto& prev_num_cols = (num_cols_ < cols_per_warp) ? prev_num_cols_small : prev_num_cols_large;
// Proceed using defined variables without if/else
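
This refactor relies on both groups having the same type, so a reference can be bound once and the body written a single time. A minimal compilable sketch, with `Group` as a stand-in for the PR's small/large bookkeeping struct:

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the PR's per-group bookkeeping (illustrative only).
struct Group {
  int64_t count = 0;
  int64_t total_warps = 0;
};

void add_table(Group& small, Group& large, int64_t num_cols,
               int64_t cols_per_warp, int64_t warps_needed) {
  // One ternary picks the destination; the update logic exists once.
  Group& dst = (num_cols < cols_per_warp) ? small : large;
  dst.total_warps += warps_needed;
  dst.count++;
}
```

The same trick extends to the first_table/prev_num_cols trackers, as long as each pair has matching types for the reference binding.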

Comment on lines +498 to +512
if(large.count > 0) {
args_tensor_large = args_tensor_large.to(
first_input.device(),
/*non_blocking=*/true);

// Offset raw ptrs in GPU memory
offset_args(
&large.input_ptrs,
&large.output_ptrs,
&large.indices_ptrs,
&large.warp_offsets_group,
&large.num_cols_group,
reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
args_ptrs_offsets);
}


I believe that code blocks like this could be extracted into lambdas or utility functions.
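
A stubbed sketch of that extraction (in the real code the lambda body would move the args tensor to the device and call offset_args; here the transfer is replaced by a flag so the shape of the refactor is visible):

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the per-group args bookkeeping (illustrative only).
struct GroupArgs {
  int64_t count = 0;
  bool transferred = false;
};

// One lambda replaces the duplicated "if (count > 0) { transfer + offset }"
// blocks; it is then applied to the small and large groups in turn.
auto transfer_if_nonempty = [](GroupArgs& g) {
  if (g.count > 0) {
    g.transferred = true;  // stand-in for args_tensor.to(...) + offset_args(...)
  }
};
```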

Comment on lines +928 to +936
const auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
const auto num_cols_ = grad_output_reshaped.size(1);
if(num_cols_ < cols_per_warp) {
indices_ptrs_small[idx_small] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
idx_small++;
} else {
indices_ptrs_large[idx_large] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
idx_large++;
}


We do this reshape-then-pointer-assignment several times across the code. Let's extract it into a helper.

Comment on lines +943 to +983
#ifdef USE_ROCM
// Transfer grad output pointers to GPU
args_tensor_small = args_tensor_small.to(first_indices.device(), /*non_blocking=*/true);
args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true);

if(count_small > 0) {
group_index_select_or_add_cuda(
args_tensor_small.const_data_ptr<int64_t>(),
args_tensor_small.const_data_ptr<int64_t>() + count_small,
args_tensor_small.const_data_ptr<int64_t>() + 2 * count_small,
warp_offsets_group_small,
num_cols_group_small,
fwd_input.scalar_type(),
first_indices.scalar_type(),
fwd_input.device().index(),
num_input_rows,
total_num_warps_small,
count_small,
/*use_index_select=*/false,
use_var_cols_small,
/*use_small_emb_dim=*/true);
}

if(count_large > 0) {
group_index_select_or_add_cuda(
args_tensor_large.data_ptr<int64_t>(),
args_tensor_large.data_ptr<int64_t>() + count_large,
args_tensor_large.data_ptr<int64_t>() + 2 * count_large,
warp_offsets_group_large,
num_cols_group_large,
fwd_input.scalar_type(),
first_indices.scalar_type(),
fwd_input.device().index(),
num_input_rows,
total_num_warps_large,
count_large,
/*use_index_select=*/false,
use_var_cols_large,
/*use_small_emb_dim=*/false);
}
#else

