Split Kernel Optimization for group_index_select_or_add_2d_kernel #142

aryaman-gupta wants to merge 27 commits into main_12162025_upstream

Conversation
Commits
- …g on embedding dim size
- …p_index_select_or_add_2d_kernel
- …zed small embedding dims path
- …isable optimized smallEmbD path
- …_2d_kernel_small kernel
- …data between forward and backward passes
- …adds debug code" This reverts commit 4576e59.
- …s in backward function" This reverts commit 39afe28.
- Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2206 Pull Request resolved: pytorch#5207. Revert the changes made in D87922263. Reviewed By: cthi, atalman, huydhn. Differential Revision: D88774663. fbshipit-source-id: ecc0486eb82564ebc31eac503c58a35600816548
- …aman/group-index-optimizations
avbokovoy left a comment
The overall logic looks correct to me. I've proposed several changes that should significantly reduce code duplication and diff size, and will probably make a potential CUDA integration easier (at least on the GPU-kernel and invoke-function sides).
Discussed this offline. I believe the code duplication could be reduced by having a unified kernel template for both Nvidia and ROCm paths, and using std::visit with a platform-specific std::variant as in https://github.com/ROCm/FBGEMM/pull/139/changes#diff-6f509196a8893b5345f5e615251ce85ea5f575b81c1e9136fff764a899d92562R329-R407. This way we won't need to duplicate the invoke macro, and compilation time/binary size will stay in line with expectations for each platform. It will also make adding new template parameters easier.
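A minimal sketch of the suggested dispatch, assuming illustrative type names (these are not from the FBGEMM sources): each platform declares its own std::variant of launch configurations, and a single std::visit replaces per-platform invoke macros.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <variant>

// Hypothetical launch-configuration alternatives; only ROCm exposes the
// small-embedding-dim path, so only its variant lists that alternative.
struct SmallEmbDimPath { int64_t rows_per_warp; };  // ROCm-only alternative
struct StandardPath    { int64_t warps_per_row; };  // shared alternative

#ifdef USE_ROCM
using LaunchConfig = std::variant<StandardPath, SmallEmbDimPath>;
#else
using LaunchConfig = std::variant<StandardPath>;
#endif

// One generic visitor covers every alternative that the platform's variant
// actually contains; adding a new template parameter means adding a variant
// alternative, not another macro.
int64_t warps_needed(const LaunchConfig& cfg, int64_t num_rows) {
  return std::visit(
      [&](const auto& c) -> int64_t {
        using T = std::decay_t<decltype(c)>;
        if constexpr (std::is_same_v<T, SmallEmbDimPath>) {
          // Multiple rows packed per warp: round up.
          return (num_rows + c.rows_per_warp - 1) / c.rows_per_warp;
        } else {
          // One or more warps per row.
          return num_rows * c.warps_per_row;
        }
      },
      cfg);
}
```

Because the visitor is only instantiated for the alternatives present in the platform's variant, the CUDA build never compiles the ROCm-only branch.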
// All columns are the same
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM
Do we really need to wrap this in USE_ROCM if we already have an if constexpr (USE_SMALL_EMB_DIM)?
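A standalone sketch of why the extra #ifdef may be redundant: code behind a compile-time template flag is discarded from the false instantiation anyway. USE_SMALL_EMB_DIM here mirrors the PR's template parameter; the function body is illustrative, not the actual kernel code.

```cpp
#include <cassert>
#include <cstdint>

// The guarded branch exists only in the <true> instantiation, so wrapping
// the same code in #ifdef USE_ROCM would matter only if it cannot even
// parse/compile on the CUDA toolchain.
template <bool USE_SMALL_EMB_DIM>
int64_t first_row_for_warp(int64_t member_warp_id, int64_t rows_per_warp) {
  if constexpr (USE_SMALL_EMB_DIM) {
    // Small-embedding path: each warp covers rows_per_warp rows.
    return member_warp_id * rows_per_warp;
  } else {
    // Standard path: the warp id within the member is the row.
    return member_warp_id;
  }
}
```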
#endif
}

#ifdef USE_ROCM
Same comment as in https://github.com/ROCm/FBGEMM/pull/142/changes#r2758719145
int64_t start_row = member_warp_id * rows_per_warp;

// Since we are processing multiple rows within the warp, we need to
// map each lane to a specific row, in addition to the column
const auto local_row = (threadIdx.x * UNROLL_FACTOR) /
    num_cols; // the row ID within the set of rows handled by this warp
const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
const int64_t current_row = start_row +
    local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if
// [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are
// within num_work_rows
-  const bool use_var_cols) {
+  const bool use_var_cols
+  #ifdef USE_ROCM
+      , const bool use_small_emb_dim
We could force use_small_emb_dim=false on the caller side in the CUDA case and keep the API consistent.
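A hypothetical stand-in illustrating the suggestion (pick_kernel is made up, not the FBGEMM function): keep use_small_emb_dim in the signature on both platforms and have CUDA call sites pass false, instead of #ifdef-ing the parameter in and out.

```cpp
#include <cassert>
#include <string>

// Stand-in for a launcher whose signature is identical on CUDA and ROCm.
// On CUDA the caller simply passes /*use_small_emb_dim=*/false, so the
// small-embedding path is never selected there.
std::string pick_kernel(bool use_var_cols, bool use_small_emb_dim) {
  if (use_small_emb_dim) {
    return "small_emb_dim_kernel";  // ROCm-only specialization
  }
  return use_var_cols ? "var_cols_kernel" : "uniform_cols_kernel";
}
```

This keeps every call site uniform and confines the platform difference to the value passed, not the shape of the API.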
TORCH_CHECK(saved_data_t_large.is_contiguous(), "Tensor saved_data_t_large must be contiguous.");
memcpy(saved_data_t_large.mutable_data_ptr<int64_t>(), saved_data_large, sizeof(saved_data_large));

if(small.count > 0) {
We already have an early-exit condition in group_index_select_or_add_cuda for the group_size == 0 case, so this guard and the one below are redundant.
#ifdef USE_ROCM
if (num_cols_ < cols_per_warp) {
  // Optimization for Small Embedding: Pack multiple rows per warp

  if(!first_small_table && num_cols_ != prev_num_cols_small) {
    use_var_cols_small = true;
  }
  first_small_table = false;
  prev_num_cols_small = num_cols_;
  small.input_ptrs[small.count] = reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
  small.output_ptrs[small.count] = reinterpret_cast<int64_t>(output.mutable_data_ptr());
  small.indices_ptrs[small.count] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
  small.num_cols_group[small.count] = num_cols_;
  small.warp_offsets_group[small.count] = small.total_warps;
  small.total_warps += warps_needed;
  small.count++;
} else {
  // Standard Embedding: One or more warps per row

  if(!first_large_table && num_cols_ != prev_num_cols_large) {
    use_var_cols_large = true;
  }
  first_large_table = false;
  prev_num_cols_large = num_cols_;

  large.input_ptrs[large.count] = reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
  large.output_ptrs[large.count] = reinterpret_cast<int64_t>(output.mutable_data_ptr());
  large.indices_ptrs[large.count] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
  large.num_cols_group[large.count] = num_cols_;
  large.warp_offsets_group[large.count] = large.total_warps;
  large.total_warps += warps_needed;
  large.count++;
}
It seems like in this block of code we are doing the same thing in both the if and else branches. Maybe it's a good idea to choose the group (large/small) via a ternary condition and then apply the operation to the chosen group. Something like:

auto& dst_groups = (num_cols_ < cols_per_warp) ? small : large;
auto& first_table = (num_cols_ < cols_per_warp) ? first_small_table : first_large_table;
auto& prev_num_cols = (num_cols_ < cols_per_warp) ? prev_num_cols_small : prev_num_cols_large;
// Proceed using the defined references without if/else

if(large.count > 0) {
  args_tensor_large = args_tensor_large.to(
      first_input.device(),
      /*non_blocking=*/true);

  // Offset raw ptrs in GPU memory
  offset_args(
      &large.input_ptrs,
      &large.output_ptrs,
      &large.indices_ptrs,
      &large.warp_offsets_group,
      &large.num_cols_group,
      reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
      args_ptrs_offsets);
}
I believe that code blocks like this one could be extracted into lambdas or utility functions.
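A Torch-free sketch of the extraction, assuming a hypothetical GroupArgs struct (field and function names are illustrative, and the rebasing arithmetic is a simplification of what offset_args does): the identical small/large blocks collapse into one helper called once per group.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kMaxGroups = 8;  // illustrative capacity

// Hypothetical mirror of the per-group host-side argument arrays.
struct GroupArgs {
  std::array<int64_t, kMaxGroups> input_ptrs{}, output_ptrs{}, indices_ptrs{};
  int count = 0;
};

// Rebase each recorded pointer onto the group's packed device buffer.
// Calling this once for `small` and once for `large` replaces the two
// duplicated `if (group.count > 0) { ... }` blocks.
void offset_group_args(GroupArgs& g, int64_t device_base) {
  for (int i = 0; i < g.count; ++i) {
    g.input_ptrs[i]   += device_base;
    g.output_ptrs[i]  += device_base;
    g.indices_ptrs[i] += device_base;
  }
}
```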
const auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
const auto num_cols_ = grad_output_reshaped.size(1);
if(num_cols_ < cols_per_warp) {
  indices_ptrs_small[idx_small] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
  idx_small++;
} else {
  indices_ptrs_large[idx_large] = reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
  idx_large++;
}
We do such reshape-then-pointer-assignment several times across the code. Let's extract it.
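A sketch of that extraction with made-up names (IndexBuckets and bucket_by_width are hypothetical): one helper owns the "compare the width, then append the pointer to the small or large bucket" pattern that is currently copy-pasted at each call site.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical container for the two per-width pointer groups.
struct IndexBuckets {
  std::vector<int64_t> small_ptrs, large_ptrs;
};

// Route one indices pointer to the small or large bucket based on the
// embedding width, mirroring the if/else repeated across the diff.
void bucket_by_width(IndexBuckets& b, int64_t num_cols,
                     int64_t cols_per_warp, int64_t indices_ptr) {
  auto& dst = (num_cols < cols_per_warp) ? b.small_ptrs : b.large_ptrs;
  dst.push_back(indices_ptr);
}
```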
#ifdef USE_ROCM
// Transfer grad output pointers to GPU
args_tensor_small = args_tensor_small.to(first_indices.device(), /*non_blocking=*/true);
args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true);

if(count_small > 0) {
  group_index_select_or_add_cuda(
      args_tensor_small.const_data_ptr<int64_t>(),
      args_tensor_small.const_data_ptr<int64_t>() + count_small,
      args_tensor_small.const_data_ptr<int64_t>() + 2 * count_small,
      warp_offsets_group_small,
      num_cols_group_small,
      fwd_input.scalar_type(),
      first_indices.scalar_type(),
      fwd_input.device().index(),
      num_input_rows,
      total_num_warps_small,
      count_small,
      /*use_index_select=*/false,
      use_var_cols_small,
      /*use_small_emb_dim=*/true);
}

if(count_large > 0) {
  group_index_select_or_add_cuda(
      args_tensor_large.data_ptr<int64_t>(),
      args_tensor_large.data_ptr<int64_t>() + count_large,
      args_tensor_large.data_ptr<int64_t>() + 2 * count_large,
      warp_offsets_group_large,
      num_cols_group_large,
      fwd_input.scalar_type(),
      first_indices.scalar_type(),
      fwd_input.device().index(),
      num_input_rows,
      total_num_warps_large,
      count_large,
      /*use_index_select=*/false,
      use_var_cols_large,
      /*use_small_emb_dim=*/false);
}
#else
Same comment as in https://github.com/ROCm/FBGEMM/pull/142/changes#r2759023013
Overview

Optimizes group_index_select operations on ROCm by splitting tables into small and large embedding groups and launching specialized kernels for each.

Implementation

Introduces a compile-time template parameter USE_SMALL_EMB_DIM that splits kernel execution:
- Small embedding dims (num_cols < cols_per_warp): pack multiple rows per warp for better occupancy
- Large embedding dims (num_cols >= cols_per_warp): standard one-or-more warps per row

Tables are separated into two groups during host-side processing, then launched with two separate kernel invocations.
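The small-embedding row packing above can be sketched host-side. This mirrors the lane mapping shown in the kernel diff (each lane handles UNROLL_FACTOR consecutive elements, so its row and column within the warp's tile fall out of a divide and a modulo by num_cols); the struct and function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

struct LanePos {
  int64_t local_row;    // row within the set of rows handled by this warp
  int64_t col_offset;   // column within that row
  int64_t current_row;  // actual row within the table for this lane
};

// Map a lane's thread index to its (row, column) position, as in the
// kernel's small-embedding path: lane element offset = thread_idx * unroll.
LanePos map_lane(int64_t thread_idx, int64_t unroll_factor,
                 int64_t num_cols, int64_t start_row) {
  const int64_t elem = thread_idx * unroll_factor;
  const int64_t local_row = elem / num_cols;
  return {local_row, elem % num_cols, start_row + local_row};
}
```

For example, with num_cols = 4 and UNROLL_FACTOR = 2, lane 5 covers element 10, landing at local row 2, column 2, so one warp services many short rows at once.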
Why It Works

- Independent use_var_cols tracking: each group (small/large) maintains its own use_var_cols flag. If all small embeddings have identical dimensions, the small kernel skips the variable-column logic entirely, even when the large embeddings vary.
- Compact kernel generation: compile-time specialization eliminates runtime branching between the small/large paths, producing tighter kernels.
Files Modified

Kernel implementation (sparse_group_index.cu)
- Adds the USE_SMALL_EMB_DIM template parameter to group_index_select_or_add_2d_kernel for compile-time kernel specialization

Host GPU code (sparse_ops_gpu.cpp)
- Splits arguments into args_tensor_small/large and saved_data_small/large and launches separate kernels
- Returns group_size + 4 tensors on ROCm (2 args_tensors + 2 saved_data) vs group_size + 2 on CUDA

CPU wrapper (sparse_ops_cpu.cpp)
- group_index_select remains unmodified and returns 2 elements, while the ROCm GPU implementation now returns 4

Python interface (sparse_ops.py)
- Detects ROCm (torch.version.hip is not None) to allocate the appropriate number of tensors

Header (sparse_ops.h)
- Adds a use_small_emb_dim parameter to the group_index_select_or_add_cuda function signature to control kernel specialization (conditionally compiled for ROCm only)