From 85caa29d934af46192aeb590fed2d5371754e3ce Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 12 Dec 2025 15:09:14 +0000 Subject: [PATCH 01/25] adds optimized path for small dimension sizes to group_index_select_or_add_2d_kernel --- .../src/sparse_ops/sparse_group_index.cu | 76 ++++++++++++++----- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 14 +++- 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 96c57cde68..a0584b23de 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -83,25 +83,65 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + + if (num_cols < COLS_PER_WARP) { + // Optimized path for small embedding dimensions + // Each warp processes 'rows_per_warp' rows + int rows_per_warp = COLS_PER_WARP / num_cols; + int64_t start_row = member_warp_id * rows_per_warp; + + // Since we are processing multiple rows within the warp, we need to + // map each lane to a specific row, in addition to the column + int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp + int col = (threadIdx.x * UNROLL_FACTOR) % num_cols; + int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane + + // local_row may be out of bounds for the last few lanes in the warp + // if [COLS_PER_WARP % num_cols != 0] + // TODO: check if current_row < num_work_rows is necessary + if (local_row < rows_per_warp && current_row < num_work_rows) { + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + index_t idx = indices[current_row]; + + scalar_t* input_base = reinterpret_cast(input_ptrs[member_id]); + scalar_t* output_base = reinterpret_cast(output_ptrs[member_id]); + +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) { + if constexpr (USE_INDEX_SELECT) { + output_base[current_row * num_cols + col] = + LDG(&input_base[idx * num_cols + col]); + } else { + gpuAtomicAddNoReturn( + &output_base[idx * num_cols + col], + input_base[current_row * num_cols + col]); + } + } + } + } else { + // Large embedding dimensions use >= 1 warp per row + + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; + #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * 
num_cols + i], input[row * num_cols + i]);
+      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+        if constexpr (USE_INDEX_SELECT) {
+          output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+        } else {
+          gpuAtomicAddNoReturn(
+              &output[idx * num_cols + i], input[row * num_cols + i]);
+        }
       }
     }
   }

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 9e8587b8d1..bdd0f13652 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -303,7 +303,17 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(

     // Number of columns can be different
     auto num_cols_ = input_reshaped_.size(1);
-    auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+
+    int64_t warps_needed;
+    if (num_cols_ < cols_per_warp) {
+      // Optimization: Pack multiple rows into one warp
+      int rows_per_warp = cols_per_warp / num_cols_;
+      warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
+    } else {
+      // Standard: One or more warps per row
+      int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+      warps_needed = warps_per_row * num_output_rows_;
+    }

     if (num_cols != num_cols_) {
       use_var_cols = true;
@@ -329,7 +339,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       warp_offsets_group[i] = warp_offset;
       num_cols_group[i] = num_cols_;

-      warp_offset += warps_per_row * num_output_rows;
+      warp_offset += warps_needed;
     }

     // Store the last offset

From e4f1dba367ccfb2359e2e28c021762b024972678 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Tue, 16 Dec 2025 14:36:05 +0000
Subject: [PATCH 02/25] group_index_select_or_add_2d_kernel: splits into two
 kernels depending on embedding dim size

---
 fbgemm_gpu/fbgemm_gpu/sparse_ops.py          |   9 +
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h   |  15 +
 .../src/sparse_ops/sparse_group_index.cu     | 187 +++++++++
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 362 +++++++++++++-----
 4 files changed, 473 insertions(+), 100 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index 4df2d2bfe3..1567b241b8 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -947,7 +947,16 @@ def group_index_select_dim0_gpu_impl_abstract(
         )
     )

+    ret.append(
+        # sizeof(int64_t) = 8, torch.uint8 = at::kByte
+        input_group[0].new_empty(
+            args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
+        )
+    )
+
+    ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
+    ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))

     return ret
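Note: the two 5-element int64 tensors registered above mirror the saved_data
blobs that the GPU forward pass fills in, one for the small-dimension bucket
and one for the large-dimension bucket. As a reading aid only, the slot layout
can be pictured as the struct below; the field names are ours, only the five
int64 slots come from the patch.

#include <cstdint>

// Hypothetical mirror of one 5-slot saved_data tensor.
struct SavedGroupArgs {
  int64_t count;             // members in this bucket (small.count / large.count)
  int64_t use_var_cols;      // whether column counts differ within the bucket
  int64_t warp_offsets_addr; // address of the bucket's warp_offsets_group
  int64_t num_cols_addr;     // address of the bucket's num_cols_group
  int64_t total_warps;       // total warps assigned to the bucket
};

static_assert(
    sizeof(SavedGroupArgs) == 5 * sizeof(int64_t),
    "matches torch.zeros(5, dtype=torch.int64)");

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index efebf3ac02..5ce87e7b0a 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1088,6 +1088,21 @@ void group_index_select_or_add_cuda(
     const bool use_index_select,
     const bool use_var_cols);

+void group_index_select_or_add_cuda_smallEmbD(
+    const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
+    const int64_t* indices_ptrs,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
+    const c10::ScalarType& input_scalar_type,
+    const c10::ScalarType& indices_scalar_type,
+    const c10::DeviceIndex& device,
+    const int num_work_rows,
+    const int64_t total_num_warps,
+    const int group_size,
+    const bool use_index_select,
+    const bool use_var_cols);
+
 int get_group_index_select_cols_per_warp();

 std::vector 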
jagged_index_select_2d( diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index a0584b23de..ddbaa89a27 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -33,6 +33,120 @@ int get_group_index_select_cols_per_warp() { return GROUP_INDEX_SELECT_COLS_PER_WARP; } +template < + typename index_t, + typename scalar_t, + bool USE_INDEX_SELECT, + bool USE_VAR_COLS, + int UNROLL_FACTOR, + int COLS_PER_WARP, + int LOG_COLS_PER_WARP> +__global__ +__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( + const int64_t* input_ptrs, + const int64_t* output_ptrs, + const int64_t* indices_ptrs, + const int64_t* warp_offsets_group, + const int32_t* num_cols_group, + const int64_t num_work_rows, // number of rows to work on per member + const int64_t group_size) { + const auto total_num_warps = warp_offsets_group[group_size]; + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + } else { + // All columns are the same + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + + if (num_cols < COLS_PER_WARP) { + // Optimized path for small embedding dimensions + // Each warp processes 'rows_per_warp' rows + int rows_per_warp = COLS_PER_WARP / num_cols; + int64_t start_row = member_warp_id * rows_per_warp; + + // Since we are processing multiple rows within the warp, we need to + // map each lane to a specific row, in addition to the column + int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp + int col = (threadIdx.x * UNROLL_FACTOR) % num_cols; + int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane + + // local_row may be out of bounds for the last few lanes in the warp + // if [COLS_PER_WARP % num_cols != 0] + // TODO: check if current_row < num_work_rows is necessary + if (local_row < rows_per_warp && current_row < num_work_rows) { + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + index_t idx = indices[current_row]; + + scalar_t* input_base = reinterpret_cast(input_ptrs[member_id]); + scalar_t* output_base = reinterpret_cast(output_ptrs[member_id]); + +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) { + if constexpr (USE_INDEX_SELECT) { + output_base[current_row * num_cols + col] = + LDG(&input_base[idx * num_cols + col]); + } else { + gpuAtomicAddNoReturn( + &output_base[idx * num_cols + col], + input_base[current_row * num_cols + col]); + } + } + } + } else { + // Large embedding dimensions use >= 1 warp per row + + const auto row = member_warp_id / 
warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; + +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + if constexpr (USE_INDEX_SELECT) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } else { + gpuAtomicAddNoReturn( + &output[idx * num_cols + i], input[row * num_cols + i]); + } + } + } + } +} + template < typename index_t, typename scalar_t, @@ -221,4 +335,77 @@ DLL_PUBLIC void group_index_select_or_add_cuda( #undef INVOKE_GROUP_INDEX_SELECT_OR_ADD } +DLL_PUBLIC void group_index_select_or_add_cuda_smallEmbD( + const int64_t* input_ptrs, + const int64_t* output_ptrs, + const int64_t* indices_ptrs, + const int64_t* warp_offsets_group, + const int32_t* num_cols_group, + const c10::ScalarType& input_scalar_type, + const c10::ScalarType& indices_scalar_type, + const c10::DeviceIndex& device, + const int num_work_rows, + const int64_t total_num_warps, + const int group_size, + const bool use_index_select, + const bool use_var_cols) { + if (group_size == 0) { + return; + } + + at::cuda::OptionalCUDAGuard device_guard(device); + + // Partition work based on num_work_rows + uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; + uint32_t max_grid_size = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; + uint32_t grid_size = std::min( + cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), + max_grid_size); + dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); + + // Launcher Macro for Small Kernel +#define INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(USE_INDEX_SELECT, USE_VAR_COLS) \ + FBGEMM_LAUNCH_KERNEL( \ + (group_index_select_or_add_2d_kernel_small< \ + index_t, \ + scalar_t, \ + USE_INDEX_SELECT, \ + USE_VAR_COLS, \ + GROUP_INDEX_SELECT_COLS_PER_WARP>), \ + grid_size, \ + block_size, \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + input_ptrs, \ + output_ptrs, \ + indices_ptrs, \ + warp_offsets_group, \ + num_cols_group, \ + num_work_rows, \ + group_size) + + AT_DISPATCH_INDEX_TYPES( + indices_scalar_type, "group_index_select_2d_small_wrapper_1", [&] { + FBGEMM_DISPATCH_FLOATING_TYPES( + input_scalar_type, "group_index_select_2d_small_wrapper_2", [&] { + if (use_index_select) { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(true, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(true, false); + } + } else { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(false, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(false, false); + } + } + }); + }); + +#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index bdd0f13652..8f0fa06b87 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -232,28 +232,48 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( args_ptrs_offsets[NUM_ARGS] = offset; // Allocate memory for GroupIndexSelectArgs - at::Tensor args_tensor = at::empty( + at::Tensor args_tensor_small = at::empty( {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, 
at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); + at::Tensor args_tensor_large = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - // Initialize raw pointers to point to Tensor args_tensor - int64_t* input_ptrs = nullptr; - int64_t* output_ptrs = nullptr; - int64_t* indices_ptrs = nullptr; - int64_t* warp_offsets_group = nullptr; - int32_t* num_cols_group = nullptr; + // Ensure that args_tensors are contiguous + TORCH_CHECK(args_tensor_small.is_contiguous()); + TORCH_CHECK(args_tensor_large.is_contiguous()); + + // defining a struct that will maintain the arguments required by + // the GPU kernel + struct SplitArgs { + int64_t* input_ptrs = nullptr; + int64_t* output_ptrs = nullptr; + int64_t* indices_ptrs = nullptr; + int64_t* warp_offsets_group = nullptr; + int32_t* num_cols_group = nullptr; + int64_t total_warps = 0; + int64_t count = 0; + } small, large; // small and large structs to hold args for small and large + // tables respectively // Offset host pointers offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), + &small.input_ptrs, + &small.output_ptrs, + &small.indices_ptrs, + &small.warp_offsets_group, + &small.num_cols_group, + reinterpret_cast(args_tensor_small.data_ptr()), + args_ptrs_offsets); + + offset_args( + &large.input_ptrs, + &large.output_ptrs, + &large.indices_ptrs, + &large.warp_offsets_group, + &large.num_cols_group, + reinterpret_cast(args_tensor_large.data_ptr()), args_ptrs_offsets); auto& first_input = input_group[0]; @@ -315,6 +335,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( warps_needed = warps_per_row * num_output_rows_; } + // TODO: maintain [use_var_cols] separately for small emb dims if (num_cols != num_cols_) { use_var_cols = true; } @@ -332,83 +353,153 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_contigs.push_back(input.expect_contiguous()); index_contigs.push_back(indices.expect_contiguous()); - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; - - warp_offset += warps_needed; + if (num_cols_ < cols_per_warp) { + // Optimization for Small Embedding: Pack multiple rows per warp + small.input_ptrs[small.count] = reinterpret_cast(input_contigs[i]->data_ptr()); + small.output_ptrs[small.count] = reinterpret_cast(output.data_ptr()); + small.indices_ptrs[small.count] = reinterpret_cast(index_contigs[i]->data_ptr()); + small.num_cols_group[small.count] = num_cols_; + small.warp_offsets_group[small.count] = small.total_warps; + small.total_warps += warps_needed; + small.count++; + } else { + // Standard Embedding: One or more warps per row + large.input_ptrs[large.count] = reinterpret_cast(input_contigs[i]->data_ptr()); + large.output_ptrs[large.count] = reinterpret_cast(output.data_ptr()); + large.indices_ptrs[large.count] = reinterpret_cast(index_contigs[i]->data_ptr()); + large.num_cols_group[large.count] = num_cols_; + large.warp_offsets_group[large.count] = large.total_warps; + large.total_warps += warps_needed; + large.count++; + } } // Store the last offset - 
warp_offsets_group[group_size] = warp_offset; + if (small.count > 0) { + small.warp_offsets_group[small.count] = small.total_warps; + } + if (large.count > 0) { + large.warp_offsets_group[large.count] = large.total_warps; + } - // Transfer args tensor to GPU - args_tensor = args_tensor.to( - first_input.device(), - /*non_blocking=*/true); + // Transfer args tensors to GPU + if(small.count > 0) { + args_tensor_small = args_tensor_small.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &small.input_ptrs, + &small.output_ptrs, + &small.indices_ptrs, + &small.warp_offsets_group, + &small.num_cols_group, + reinterpret_cast(args_tensor_small.data_ptr()), + args_ptrs_offsets); + } - // Offset raw ptrs in GPU memory - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); + if(large.count > 0) { + args_tensor_large = args_tensor_large.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &large.input_ptrs, + &large.output_ptrs, + &large.indices_ptrs, + &large.warp_offsets_group, + &large.num_cols_group, + reinterpret_cast(args_tensor_large.data_ptr()), + args_ptrs_offsets); + } + + int64_t saved_data_small[] = { + static_cast(small.count), + use_var_cols, + reinterpret_cast(small.warp_offsets_group), + reinterpret_cast(small.num_cols_group), + small.total_warps, + }; + auto saved_data_t_small = at::empty( + {sizeof(saved_data_small) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_t_small.is_contiguous()); + memcpy(saved_data_t_small.data_ptr(), saved_data_small, sizeof(saved_data_small)); - int64_t saved_data[] = { - static_cast(group_size), + int64_t saved_data_large[] = { + static_cast(large.count), use_var_cols, - reinterpret_cast(warp_offsets_group), - reinterpret_cast(num_cols_group), - warp_offset, + reinterpret_cast(large.warp_offsets_group), + reinterpret_cast(large.num_cols_group), + large.total_warps, }; - auto saved_data_t = at::empty( - {sizeof(saved_data) / sizeof(int64_t)}, + auto saved_data_t_large = at::empty( + {sizeof(saved_data_large) / sizeof(int64_t)}, at::TensorOptions().dtype(at::kLong)); - TORCH_CHECK(saved_data_t.is_contiguous()); - memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); - - group_index_select_or_add_cuda( - input_ptrs, - output_ptrs, - indices_ptrs, - warp_offsets_group, - num_cols_group, - first_input.scalar_type(), - first_indices.scalar_type(), - first_input.device().index(), - num_output_rows, - /*total_num_warps=*/warp_offset, - group_size, - /*use_index_select=*/true, - use_var_cols); + TORCH_CHECK(saved_data_t_large.is_contiguous()); + memcpy(saved_data_t_large.data_ptr(), saved_data_large, sizeof(saved_data_large)); + + if(small.count > 0) { + group_index_select_or_add_cuda_smallEmbD( + small.input_ptrs, + small.output_ptrs, + small.indices_ptrs, + small.warp_offsets_group, + small.num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + num_output_rows, + /*total_num_warps=*/small.total_warps, + small.count, + /*use_index_select=*/true, + use_var_cols); + } + + if(large.count > 0) { + group_index_select_or_add_cuda( + large.input_ptrs, + large.output_ptrs, + large.indices_ptrs, + large.warp_offsets_group, + large.num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + 
num_output_rows,
+        /*total_num_warps=*/large.total_warps,
+        large.count,
+        /*use_index_select=*/true,
+        use_var_cols);
+  }

-  output_group.push_back(args_tensor);
-  output_group.push_back(saved_data_t);
+  output_group.push_back(args_tensor_small);
+  output_group.push_back(args_tensor_large);
+  output_group.push_back(saved_data_t_small);
+  output_group.push_back(saved_data_t_large);

   // return format:
-  // (group_size outputs, 1 args_tensor, 1 saved_data)
+  // (group_size outputs, 2 args_tensors, 2 saved_data)
   return output_group;
 }

 static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     at::TensorList all_inputs,
     c10::SymIntArrayRef output_shape_group_ref) {
-  TORCH_CHECK(all_inputs.size() > 2);
+  TORCH_CHECK(all_inputs.size() > 4);

   // all_input size = group_size * 2 (from grads, indices)
-  // + 1 args_tensor + 1 saved_data + 1 first input
-  const int64_t group_size = (all_inputs.size() - 3) / 2;
+  // + 2 args_tensors + 2 saved_data + 1 first input
+  const int64_t group_size = (all_inputs.size() - 5) / 2;

-  const Tensor& fwd_input = all_inputs[2 * group_size + 2];
-  const int64_t output_dim = fwd_input.dim();
-  const Tensor& saved_data = all_inputs[2 * group_size + 1];
+  const Tensor& fwd_input = all_inputs[2 * group_size + 4];
+  const Tensor& saved_data_small = all_inputs[2 * group_size + 2];
+  const Tensor& saved_data_large = all_inputs[2 * group_size + 3];
+  const Tensor& first_indices = all_inputs[group_size];
+  const int64_t output_dim = fwd_input.dim();

   auto grad_output_group = std::vector(
       all_inputs.cbegin(), all_inputs.cbegin() + group_size);
@@ -422,15 +513,23 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);

   // Retrieve saved data
-  TORCH_CHECK(saved_data.device() == at::kCPU);
-  TORCH_CHECK(saved_data.is_contiguous());
-  int64_t* saved_data_ptr = saved_data.data_ptr();
-  // Check that the size is the same
-  TORCH_CHECK(saved_data_ptr[0] == group_size);
-  const bool use_var_cols = saved_data_ptr[1];
-  int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]);
-  int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]);
-  int64_t total_num_warps = saved_data_ptr[4];
+  TORCH_CHECK(saved_data_small.device() == at::kCPU);
+  TORCH_CHECK(saved_data_small.is_contiguous());
+  int64_t* saved_data_small_ptr = saved_data_small.data_ptr();
+  auto count_small = saved_data_small_ptr[0];
+  const bool use_var_cols = saved_data_small_ptr[1];
+  int64_t* warp_offsets_group_small = reinterpret_cast(saved_data_small_ptr[2]);
+  int32_t* num_cols_group_small = reinterpret_cast(saved_data_small_ptr[3]);
+  int64_t total_num_warps_small = saved_data_small_ptr[4];
+
+  TORCH_CHECK(saved_data_large.device() == at::kCPU);
+  TORCH_CHECK(saved_data_large.is_contiguous());
+  int64_t* saved_data_large_ptr = saved_data_large.data_ptr();
+  auto count_large = saved_data_large_ptr[0];
+  const bool use_var_cols_large = saved_data_large_ptr[1];
+  int64_t* warp_offsets_group_large = reinterpret_cast(saved_data_large_ptr[2]);
+  int32_t* num_cols_group_large = reinterpret_cast(saved_data_large_ptr[3]);
+  int64_t total_num_warps_large = saved_data_large_ptr[4];

   // We checked in forward that all output rows are the same for all member
   // in the group
@@ -453,15 +552,25 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
   }

-  // Allocate Tensor for ptrs of grad output and input, and 
indices - Tensor args_tensor = at::empty( - {group_size * 3}, + // Allocate Tensors for ptrs of grad output and input, and indices + Tensor args_tensor_small = at::empty( + {count_small * 3}, at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - int64_t* grad_output_ptrs = args_tensor.data_ptr(); - int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; - int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; + + Tensor args_tensor_large = at::empty( + {count_large * 3}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + // Ensure that args_tensors are contiguous + TORCH_CHECK(args_tensor_small.is_contiguous()); + TORCH_CHECK(args_tensor_large.is_contiguous()); + + int64_t* grad_output_ptrs_small = args_tensor_small.data_ptr(); + int64_t* grad_input_ptrs_small = args_tensor_small.data_ptr() + count_small; + int64_t* indices_ptrs_small = args_tensor_small.data_ptr() + 2 * count_small; + + int64_t* grad_output_ptrs_large = args_tensor_large.data_ptr(); + int64_t* grad_input_ptrs_large = args_tensor_large.data_ptr() + count_large; + int64_t* indices_ptrs_large = args_tensor_large.data_ptr() + 2 * count_large; int64_t group_grad_input_numel = 0; std::vector grad_input_numels; @@ -472,6 +581,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( std::vector> grad_output_contigs; grad_output_contigs.reserve(group_size); + const int cols_per_warp = get_group_index_select_cols_per_warp(); + int64_t idx_small = 0; + int64_t idx_large = 0; + for (const auto i : c10::irange(group_size)) { const auto& grad = grad_output_group[i]; TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); @@ -488,8 +601,15 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( group_grad_input_numel += grad_input_numel; // Put all grad output/input pointers in an array - grad_output_ptrs[i] = - reinterpret_cast(grad_output_contigs[i]->data_ptr()); + if(grad.size(1) < cols_per_warp) { + grad_output_ptrs_small[idx_small] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); + idx_small++; + } else { + grad_output_ptrs_large[idx_large] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); + idx_large++; + } } // Allocate a big tensor to avoid calling many small elementwise kernels @@ -502,6 +622,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( TORCH_CHECK(output_group.size() == static_cast(group_size)); + // Reset the counters of the small and large arguments + idx_small = 0; + idx_large = 0; + // Reshape grad inputs and obtain their pointers for (int i = 0; i < group_size; i++) { const auto grad_input_shape = std::vector( @@ -509,38 +633,76 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( output_shape_group.begin() + (i + 1) * output_dim); output_group[i] = output_group[i].reshape(grad_input_shape); TORCH_CHECK(output_group[i].is_contiguous()); - grad_input_ptrs[i] = reinterpret_cast(output_group[i].data_ptr()); + + if(grad_output_group[i].size(1) < cols_per_warp) { + grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr()); + idx_small++; + } else { + grad_input_ptrs_large[idx_large] = reinterpret_cast(output_group[i].data_ptr()); + idx_large++; + } // 2) Add group_size gradients for inputs outputs.push_back(output_group[i]); } + // Reset the counters of the small and large arguments + idx_small = 0; + idx_large = 0; + // Calculate 
indices_ptrs
   std::vector> index_contigs;
   index_contigs.reserve(group_size);
   for (const auto i : c10::irange(group_size)) {
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
-    indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr());
+
+    if(grad_output_group[i].size(1) < cols_per_warp) {
+      indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr());
+      idx_small++;
+    } else {
+      indices_ptrs_large[idx_large] = reinterpret_cast(index_contigs[i]->data_ptr());
+      idx_large++;
+    }
   }

   // Transfer grad output pointers to GPU
-  args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
-
-  group_index_select_or_add_cuda(
-      args_tensor.data_ptr(),
-      args_tensor.data_ptr() + group_size,
-      args_tensor.data_ptr() + 2 * group_size,
-      warp_offsets_group,
-      num_cols_group,
+  args_tensor_small = args_tensor_small.to(first_indices.device(), /*non_blocking=*/true);
+  args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true);
+
+  if(count_small > 0) {
+    group_index_select_or_add_cuda_smallEmbD(
+      args_tensor_small.data_ptr(),
+      args_tensor_small.data_ptr() + count_small,
+      args_tensor_small.data_ptr() + 2 * count_small,
+      warp_offsets_group_small,
+      num_cols_group_small,
       fwd_input.scalar_type(),
       first_indices.scalar_type(),
       fwd_input.device().index(),
       num_input_rows,
-      total_num_warps,
+      total_num_warps_small,
       group_size,
       /*use_index_select=*/false,
       use_var_cols);
+  }
+
+  if(count_large > 0) {
+    group_index_select_or_add_cuda(
+      args_tensor_large.data_ptr(),
+      args_tensor_large.data_ptr() + count_large,
+      args_tensor_large.data_ptr() + 2 * count_large,
+      warp_offsets_group_large,
+      num_cols_group_large,
+      fwd_input.scalar_type(),
+      first_indices.scalar_type(),
+      fwd_input.device().index(),
+      num_input_rows,
+      total_num_warps_large,
+      group_size,
+      /*use_index_select=*/false,
+      use_var_cols_large);
+  }

   return outputs;
 }

From ff1b9b6c70f9483cf2ed23e3afbd76c3a599cf9e Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Tue, 16 Dec 2025 17:27:58 +0000
Subject: [PATCH 03/25] sparse_group_index.cu: edits some comments

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index a0584b23de..a8eb081610 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -96,9 +96,8 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       int col = (threadIdx.x * UNROLL_FACTOR) % num_cols;
       int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane

-      // local_row may be out of bounds for the last few lanes in the warp
-      // if [COLS_PER_WARP % num_cols != 0]
-      // TODO: check if current_row < num_work_rows is necessary
+      // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
+      // and we also need to confirm that we are within num_work_rows
       if (local_row < rows_per_warp && current_row < num_work_rows) {
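Note: the masking that these comments describe can be replayed on the host. A
minimal standalone sketch, assuming COLS_PER_WARP = 64, UNROLL_FACTOR = 1 and
a 24-column table (the real values are the GROUP_INDEX_SELECT_* constants),
prints the lanes that the [local_row < rows_per_warp] check must idle because
COLS_PER_WARP % num_cols != 0:

#include <cstdio>

int main() {
  const int kColsPerWarp = 64; // assumed stand-in for COLS_PER_WARP
  const int kNumCols = 24;     // embedding dimension, < kColsPerWarp
  const int rows_per_warp = kColsPerWarp / kNumCols; // = 2
  for (int lane = 0; lane < kColsPerWarp; ++lane) {
    const int local_row = lane / kNumCols; // 0, 1, or 2
    const int col = lane % kNumCols;
    if (local_row >= rows_per_warp) {
      // Lanes 48..63 would read a third row that this warp does not own.
      printf("lane %2d idle (local_row=%d, col=%d)\n", lane, local_row, col);
    }
  }
  return 0;
}

From 439a51a567ea6274dbffa4d2614274bc2f82d9e5 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Tue, 16 Dec 2025 18:28:08 +0000
Subject: [PATCH 04/25] adds USE_ROCM guards to subwarp optimizations for
 group_index_select_or_add_2d_kernel

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 10 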
+++++++--- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 9 +++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index a8eb081610..2c54c7bab1 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -84,6 +84,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } +#ifdef USE_ROCM if (num_cols < COLS_PER_WARP) { // Optimized path for small embedding dimensions // Each warp processes 'rows_per_warp' rows @@ -107,6 +108,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) { + // Compile time conditional if constexpr (USE_INDEX_SELECT) { output_base[current_row * num_cols + col] = LDG(&input_base[idx * num_cols + col]); @@ -119,12 +121,12 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } else { // Large embedding dimensions use >= 1 warp per row - + // which is the default codepath for non-ROCm as well +#endif // USE_ROCM const auto row = member_warp_id / warps_per_row; const auto col_offset = ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = reinterpret_cast(input_ptrs[member_id]) + col_offset; scalar_t* output = @@ -132,9 +134,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( index_t* indices = reinterpret_cast(indices_ptrs[member_id]); const index_t idx = indices[row]; - #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional if constexpr (USE_INDEX_SELECT) { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { @@ -142,7 +144,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( &output[idx * num_cols + i], input[row * num_cols + i]); } } +#ifdef USE_ROCM } +#endif // USE_ROCM } } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index bdd0f13652..d592cce6a9 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -304,6 +304,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // Number of columns can be different auto num_cols_ = input_reshaped_.size(1); +#ifdef USE_ROCM int64_t warps_needed; if (num_cols_ < cols_per_warp) { // Optimization: Pack multiple rows into one warp @@ -314,6 +315,10 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; warps_needed = warps_per_row * num_output_rows_; } +#else + // Standard: One or more warps per row + auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; +#endif // USE_ROCM if (num_cols != num_cols_) { use_var_cols = true; @@ -339,7 +344,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( warp_offsets_group[i] = warp_offset; num_cols_group[i] = num_cols_; +#ifdef USE_ROCM warp_offset += warps_needed; +#else + warp_offset += warps_per_row * num_output_rows; +#endif // USE_ROCM } // Store the last offset From 2a85d73f669959ff86859c216edae4b41db4f53b Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 18 Dec 2025 10:11:29 +0000 Subject: [PATCH 05/25] sparse_group_index: 
handle UNROLL_FACTOR for small dimensions in group_index_select_or_add_2d_kernel

---
 .../src/sparse_ops/sparse_group_index.cu | 21 +++++++++----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 2c54c7bab1..5e87d96961 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -94,28 +94,27 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       // Since we are processing multiple rows within the warp, we need to
       // map each lane to a specific row, in addition to the column
       int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
-      int col = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+      int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
       int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane

       // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
       // and we also need to confirm that we are within num_work_rows
       if (local_row < rows_per_warp && current_row < num_work_rows) {
-        index_t* indices = reinterpret_cast(indices_ptrs[member_id]);
-        index_t idx = indices[current_row];
-
-        scalar_t* input_base = reinterpret_cast(input_ptrs[member_id]);
-        scalar_t* output_base = reinterpret_cast(output_ptrs[member_id]);
+        scalar_t* input =
+            reinterpret_cast(input_ptrs[member_id]) + col_offset;
+        scalar_t* output =
+            reinterpret_cast(output_ptrs[member_id]) + col_offset;
+        index_t* indices = reinterpret_cast(indices_ptrs[member_id]);
+        const index_t idx = indices[current_row];

 #pragma unroll
-        for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) {
+        for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
           // Compile time conditional
           if constexpr (USE_INDEX_SELECT) {
-            output_base[current_row * num_cols + col] =
-                LDG(&input_base[idx * num_cols + col]);
+            output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
           } else {
             gpuAtomicAddNoReturn(
-                &output_base[idx * num_cols + col],
-                input_base[current_row * num_cols + col]);
+                &output[idx * num_cols + i], input[current_row * num_cols + i]);
           }
         }
       }
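Note: with UNROLL_FACTOR > 1, each lane now owns the slice
[col_offset, col_offset + UNROLL_FACTOR) of its row, and the
[col_offset + i < num_cols] guard keeps the unrolled loop inside the row. A
host-side coverage check, under assumed parameters (COLS_PER_WARP = 64,
UNROLL_FACTOR = 2, num_cols = 6; num_cols is chosen here as a multiple of the
unroll factor so that coverage of the warp's rows is exact):

#include <cassert>
#include <vector>

int main() {
  const int kColsPerWarp = 64; // assumed COLS_PER_WARP
  const int kUnroll = 2;       // assumed UNROLL_FACTOR
  const int kNumCols = 6;      // embedding dimension
  const int rows_per_warp = kColsPerWarp / kNumCols; // = 10
  std::vector<int> hits(rows_per_warp * kNumCols, 0);
  // One emulated warp; lane l starts at element l * kUnroll.
  for (int lane = 0; lane < kColsPerWarp / kUnroll; ++lane) {
    const int local_row = (lane * kUnroll) / kNumCols;
    const int col_offset = (lane * kUnroll) % kNumCols;
    if (local_row >= rows_per_warp) {
      continue; // row guard, as in the kernel
    }
    for (int i = 0; i < kUnroll && col_offset + i < kNumCols; ++i) {
      hits[local_row * kNumCols + col_offset + i]++;
    }
  }
  for (int h : hits) {
    assert(h == 1); // every element of the warp's rows touched exactly once
  }
  return 0;
}

From 2f541407b26c53140e857fee4e3cbc217beb5864 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 18 Dec 2025 13:26:01 +0000
Subject: [PATCH 06/25] sparse_group_index: handle fixed-column-size case
 correctly in optimized small embedding dims path

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 5e87d96961..3f191c6100 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -82,6 +82,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       // All columns are the same
       member_id = warp_id / (warps_per_row * num_work_rows);
       member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+#ifdef USE_ROCM
+      if (num_cols < COLS_PER_WARP) {
+        // Need to ensure that [member_id] and [member_warp_id] are calculated correctly
+        // for the small embedding dimension path below
+        int rows_per_warp = COLS_PER_WARP / num_cols;
+        auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp;
+        member_id = warp_id / warps_per_member;
+        member_warp_id = warp_id % 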
warps_per_member; + } +#endif // USE_ROCM } #ifdef USE_ROCM From e0edc4095377302c84f9450ec400800c5a98dcba Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 18 Dec 2025 14:27:06 +0000 Subject: [PATCH 07/25] group_index_select_or_add_2d_kernel: when num_cols < UNROLL_FACTOR, disable optimized smallEmbD path --- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 2 ++ fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 8 ++++++-- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index efebf3ac02..6f3f8c246c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -1090,6 +1090,8 @@ void group_index_select_or_add_cuda( int get_group_index_select_cols_per_warp(); +int get_group_index_select_unroll_factor(); + std::vector jagged_index_select_2d( const at::Tensor& values, const at::Tensor& lengths, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 3f191c6100..12ed1045d4 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -33,6 +33,10 @@ int get_group_index_select_cols_per_warp() { return GROUP_INDEX_SELECT_COLS_PER_WARP; } +int get_group_index_select_unroll_factor() { + return GROUP_INDEX_SELECT_UNROLL_FACTOR; +} + template < typename index_t, typename scalar_t, @@ -83,7 +87,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); #ifdef USE_ROCM - if (num_cols < COLS_PER_WARP) { + if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { // Need to ensure that [member_id] and [member_warp_id] are calculated correctly // for the small embedding dimension path below int rows_per_warp = COLS_PER_WARP / num_cols; @@ -95,7 +99,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } #ifdef USE_ROCM - if (num_cols < COLS_PER_WARP) { + if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { // Optimized path for small embedding dimensions // Each warp processes 'rows_per_warp' rows int rows_per_warp = COLS_PER_WARP / num_cols; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index d592cce6a9..df3f49af5c 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -265,6 +265,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); const int num_cols = input_reshaped.size(1); const int cols_per_warp = get_group_index_select_cols_per_warp(); + const int unroll_factor = get_group_index_select_unroll_factor(); int64_t warp_offset = 0; bool use_var_cols = false; @@ -306,7 +307,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( #ifdef USE_ROCM int64_t warps_needed; - if (num_cols_ < cols_per_warp) { + if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) { // Optimization: Pack multiple rows into one warp int rows_per_warp = cols_per_warp / num_cols_; warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp; From 81bf648002a50ff69d9d4f84295e32d0e189333b Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 7 Jan 2026 14:14:49 +0000 Subject: [PATCH 
08/25] sparse_group_index: corrects macro invoking group_index_select_or_add_2d_kernel_small kernel --- fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index ddbaa89a27..48746825a9 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -372,7 +372,9 @@ DLL_PUBLIC void group_index_select_or_add_cuda_smallEmbD( scalar_t, \ USE_INDEX_SELECT, \ USE_VAR_COLS, \ - GROUP_INDEX_SELECT_COLS_PER_WARP>), \ + GROUP_INDEX_SELECT_UNROLL_FACTOR, \ + GROUP_INDEX_SELECT_COLS_PER_WARP, \ + GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ grid_size, \ block_size, \ 0, \ From 17d8d4ce9201f559f1d3c43f7794f3755e26fbe4 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 12 Jan 2026 13:54:56 +0000 Subject: [PATCH 09/25] fixes merge commit --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 58dd63157c..9a3aa96bca 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -359,7 +359,6 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_contigs.push_back(input.expect_contiguous()); index_contigs.push_back(indices.expect_contiguous()); -<<<<<<< HEAD if (num_cols_ < cols_per_warp) { // Optimization for Small Embedding: Pack multiple rows per warp small.input_ptrs[small.count] = reinterpret_cast(input_contigs[i]->data_ptr()); @@ -379,20 +378,6 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( large.total_warps += warps_needed; large.count++; } -======= - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; - -#ifdef USE_ROCM - warp_offset += warps_needed; -#else - warp_offset += warps_per_row * num_output_rows; -#endif // USE_ROCM ->>>>>>> aryaman/group-index-subwarp } // Store the last offset From b6cec91684c03ffdf9796cf6cea23d9786ff6d15 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 12 Jan 2026 14:02:08 +0000 Subject: [PATCH 10/25] sparse_group_index.cu: copies updates into small kernel --- .../src/sparse_ops/sparse_group_index.cu | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 2316beab4a..04ec7b5d06 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -86,9 +86,20 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( // All columns are the same member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); +#ifdef USE_ROCM + if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { + // Need to ensure that [member_id] and [member_warp_id] are calculated correctly + // for the small embedding dimension path below + int rows_per_warp = COLS_PER_WARP / num_cols; + auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; + member_id = warp_id / warps_per_member; + 
member_warp_id = warp_id % warps_per_member; + } +#endif // USE_ROCM } - if (num_cols < COLS_PER_WARP) { +#ifdef USE_ROCM + if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { // Optimized path for small embedding dimensions // Each warp processes 'rows_per_warp' rows int rows_per_warp = COLS_PER_WARP / num_cols; @@ -97,39 +108,38 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( // Since we are processing multiple rows within the warp, we need to // map each lane to a specific row, in addition to the column int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp - int col = (threadIdx.x * UNROLL_FACTOR) % num_cols; + int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane - // local_row may be out of bounds for the last few lanes in the warp - // if [COLS_PER_WARP % num_cols != 0] - // TODO: check if current_row < num_work_rows is necessary + // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] + // and we also need to confirm that we are within num_work_rows if (local_row < rows_per_warp && current_row < num_work_rows) { - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - index_t idx = indices[current_row]; - - scalar_t* input_base = reinterpret_cast(input_ptrs[member_id]); - scalar_t* output_base = reinterpret_cast(output_ptrs[member_id]); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[current_row]; #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) { + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional if constexpr (USE_INDEX_SELECT) { - output_base[current_row * num_cols + col] = - LDG(&input_base[idx * num_cols + col]); + output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( - &output_base[idx * num_cols + col], - input_base[current_row * num_cols + col]); + &output[idx * num_cols + i], input[current_row * num_cols + i]); } } } } else { // Large embedding dimensions use >= 1 warp per row - + // which is the default codepath for non-ROCm as well +#endif // USE_ROCM const auto row = member_warp_id / warps_per_row; const auto col_offset = ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = reinterpret_cast(input_ptrs[member_id]) + col_offset; scalar_t* output = @@ -137,9 +147,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( index_t* indices = reinterpret_cast(indices_ptrs[member_id]); const index_t idx = indices[row]; - #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional if constexpr (USE_INDEX_SELECT) { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { @@ -147,7 +157,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( &output[idx * num_cols + i], input[row * num_cols + i]); } } +#ifdef USE_ROCM } +#endif // USE_ROCM } } From 4576e59d018442aa7da188570b603c068ebe3be6 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 14 Jan 2026 12:17:26 +0000 Subject: [PATCH 11/25] sparse_ops_gpu.cpp: 
corrects handling of forward outputs and adds debug code --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 52 +++++++++++++++----- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 9a3aa96bca..d522fa7852 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -327,7 +327,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( #ifdef USE_ROCM int64_t warps_needed; - if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) { + if (num_cols_ < 0 && num_cols_ >= unroll_factor) { // Optimization: Pack multiple rows into one warp int rows_per_warp = cols_per_warp / num_cols_; warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp; @@ -359,7 +359,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_contigs.push_back(input.expect_contiguous()); index_contigs.push_back(indices.expect_contiguous()); - if (num_cols_ < cols_per_warp) { + if (num_cols_ < 0) { // Optimization for Small Embedding: Pack multiple rows per warp small.input_ptrs[small.count] = reinterpret_cast(input_contigs[i]->data_ptr()); small.output_ptrs[small.count] = reinterpret_cast(output.data_ptr()); @@ -494,15 +494,40 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 4); - // all_input size = group_size * 2 (from grads, indices) - // + 2 args_tensor2 + 2 saved_data + 1 first input - const int64_t group_size = (all_inputs.size() - 5) / 2; + std::cout << "\n=== DEBUG BACKWARD PASS INPUTS ===" << std::endl; + std::cout << "Total all_inputs.size(): " << all_inputs.size() << std::endl; + + // Print metadata for each input tensor + for (int i = 0; i < all_inputs.size(); ++i) { + const auto& t = all_inputs[i]; + std::cout << "Input [" << i << "]: "; + if (t.defined()) { + std::cout << "Device=" << t.device() + << ", Shape=" << t.sizes() + << ", Dtype=" << t.dtype(); + } else { + std::cout << "UNDEFINED (None)"; + } + std::cout << std::endl; + } + std::cout << "==================================\n" << std::endl; - const Tensor& fwd_input = all_inputs[2 * group_size + 4]; - const Tensor& saved_data_small = all_inputs[2 * group_size + 2]; - const Tensor& saved_data_large = all_inputs[2 * group_size + 3]; + // Layout: [Grads(G) | AuxGrads(2) | Indices(G) | ArgsS, ArgsL, SavedS, SavedL, FwdInput(5)] + // Total = 2*G + 7 + TORCH_CHECK(all_inputs.size() > 7); + + // all_input size = group_size * 2 (from grads, indices) + // + 2 args_tensors + 2 gradient vectors generated by PyTorch corresponding to the args_tensors + // + 2 saved_data + 1 first input + const int64_t group_size = (all_inputs.size() - 7) / 2; + + const size_t n = all_inputs.size(); + const Tensor& fwd_input = all_inputs[n - 1]; + const Tensor& saved_data_large = all_inputs[n - 2]; + const Tensor& saved_data_small = all_inputs[n - 3]; + // const Tensor& args_tensor_large = all_inputs[n - 4]; + // const Tensor& args_tensor_small = all_inputs[n - 5]; const Tensor& first_indices = all_inputs[group_size]; const int64_t output_dim = fwd_input.dim(); @@ -515,8 +540,9 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( output_shape_group.push_back(i.as_int_unchecked()); } + // Skip G grads + 
2 aux grads. Start indices at G + 2. auto indices_group = std::vector( - all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); + all_inputs.cbegin() + group_size + 2, all_inputs.cbegin() + 2 + 2 * group_size); // Retrieve saved data TORCH_CHECK(saved_data_small.device() == at::kCPU); @@ -607,7 +633,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( group_grad_input_numel += grad_input_numel; // Put all grad output/input pointers in an array - if(grad.size(1) < cols_per_warp) { + if(grad.size(1) < 0) { grad_output_ptrs_small[idx_small] = reinterpret_cast(grad_output_contigs[i]->data_ptr()); idx_small++; @@ -640,7 +666,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( output_group[i] = output_group[i].reshape(grad_input_shape); TORCH_CHECK(output_group[i].is_contiguous()); - if(grad_output_group[i].size(1) < cols_per_warp) { + if(grad_output_group[i].size(1) < 0) { grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr()); idx_small++; } else { @@ -663,7 +689,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const auto& indices = indices_group[i]; index_contigs.push_back(indices.expect_contiguous()); - if(grad_output_group[i].size(1) < cols_per_warp) { + if(grad_output_group[i].size(1) < 0) { indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr()); idx_small++; } else { From 71be9d215686106af36b999ca623cb35eea2afc6 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 14 Jan 2026 14:34:44 +0000 Subject: [PATCH 12/25] sparse_ops_cpu: adjusts to handle potentially higher amount of saved data between forward and backward passes --- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index ba58691e8d..773ae183fc 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3512,7 +3512,7 @@ torch::autograd::variable_list group_index_select_dim0( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() == group_size + 2); + TORCH_CHECK(res.size() >= group_size + 2); // only return the outputs (the first group_size elements) res.resize(group_size); return res; @@ -3623,7 +3623,8 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); + TORCH_CHECK(static_cast(result.size()) >= group_size + 2); + ctx->saved_data["group_size"] = group_size; auto [input_group, indices_group] = group_index_select_dim0_unpack(all_indices_input, group_size); @@ -3656,17 +3657,19 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( torch::autograd::variable_list GroupIndexSelectDim0Op::backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { + + auto group_size = ctx->saved_data["group_size"].toInt(); + TORCH_CHECK(static_cast(grad_output_group.size()) >= group_size); + + if (group_size == 0) { // empty outputs return torch::autograd::variable_list(1); } // remove redundant grads - 
auto group_size = grad_output_group.size() - 2; grad_output_group.resize(group_size); auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); + TORCH_CHECK(saved_tensors.size() >= group_size + 3); std::vector output_shape_group; int i = 0; while (true) { From e5dcf52c2085188cb9782c6b5fbfd94e071ccd18 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 14 Jan 2026 14:37:22 +0000 Subject: [PATCH 13/25] Revert "sparse_ops_gpu.cpp: corrects handling of forward outputs and adds debug code" This reverts commit 4576e59d018442aa7da188570b603c068ebe3be6. --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 52 +++++--------------- 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index d522fa7852..9a3aa96bca 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -327,7 +327,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( #ifdef USE_ROCM int64_t warps_needed; - if (num_cols_ < 0 && num_cols_ >= unroll_factor) { + if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) { // Optimization: Pack multiple rows into one warp int rows_per_warp = cols_per_warp / num_cols_; warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp; @@ -359,7 +359,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_contigs.push_back(input.expect_contiguous()); index_contigs.push_back(indices.expect_contiguous()); - if (num_cols_ < 0) { + if (num_cols_ < cols_per_warp) { // Optimization for Small Embedding: Pack multiple rows per warp small.input_ptrs[small.count] = reinterpret_cast(input_contigs[i]->data_ptr()); small.output_ptrs[small.count] = reinterpret_cast(output.data_ptr()); @@ -494,40 +494,15 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { - - std::cout << "\n=== DEBUG BACKWARD PASS INPUTS ===" << std::endl; - std::cout << "Total all_inputs.size(): " << all_inputs.size() << std::endl; - - // Print metadata for each input tensor - for (int i = 0; i < all_inputs.size(); ++i) { - const auto& t = all_inputs[i]; - std::cout << "Input [" << i << "]: "; - if (t.defined()) { - std::cout << "Device=" << t.device() - << ", Shape=" << t.sizes() - << ", Dtype=" << t.dtype(); - } else { - std::cout << "UNDEFINED (None)"; - } - std::cout << std::endl; - } - std::cout << "==================================\n" << std::endl; - - // Layout: [Grads(G) | AuxGrads(2) | Indices(G) | ArgsS, ArgsL, SavedS, SavedL, FwdInput(5)] - // Total = 2*G + 7 - TORCH_CHECK(all_inputs.size() > 7); + TORCH_CHECK(all_inputs.size() > 4); // all_input size = group_size * 2 (from grads, indices) - // + 2 args_tensors + 2 gradient vectors generated by PyTorch corresponding to the args_tensors - // + 2 saved_data + 1 first input - const int64_t group_size = (all_inputs.size() - 7) / 2; - - const size_t n = all_inputs.size(); - const Tensor& fwd_input = all_inputs[n - 1]; - const Tensor& saved_data_large = all_inputs[n - 2]; - const Tensor& saved_data_small = all_inputs[n - 3]; - // const Tensor& args_tensor_large = all_inputs[n - 4]; - // const Tensor& args_tensor_small = all_inputs[n - 5]; + // + 2 args_tensor2 + 2 saved_data + 1 first input + const int64_t group_size = 
(all_inputs.size() - 5) / 2;
+
+  const Tensor& fwd_input = all_inputs[2 * group_size + 4];
+  const Tensor& saved_data_small = all_inputs[2 * group_size + 2];
+  const Tensor& saved_data_large = all_inputs[2 * group_size + 3];
   const Tensor& first_indices = all_inputs[group_size];
 
   const int64_t output_dim = fwd_input.dim();
@@ -540,9 +515,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     output_shape_group.push_back(i.as_int_unchecked());
   }
 
-  // Skip G grads + 2 aux grads. Start indices at G + 2.
   auto indices_group = std::vector(
-      all_inputs.cbegin() + group_size + 2, all_inputs.cbegin() + 2 + 2 * group_size);
+      all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);
 
   // Retrieve saved data
   TORCH_CHECK(saved_data_small.device() == at::kCPU);
@@ -633,7 +607,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     group_grad_input_numel += grad_input_numel;
 
     // Put all grad output/input pointers in an array
-    if(grad.size(1) < 0) {
+    if(grad.size(1) < cols_per_warp) {
       grad_output_ptrs_small[idx_small] = reinterpret_cast(grad_output_contigs[i]->data_ptr());
       idx_small++;
@@ -666,7 +640,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     output_group[i] = output_group[i].reshape(grad_input_shape);
     TORCH_CHECK(output_group[i].is_contiguous());
 
-    if(grad_output_group[i].size(1) < 0) {
+    if(grad_output_group[i].size(1) < cols_per_warp) {
       grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr());
       idx_small++;
     } else {
@@ -689,7 +663,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
 
-    if(grad_output_group[i].size(1) < 0) {
+    if(grad_output_group[i].size(1) < cols_per_warp) {
       indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr());
       idx_small++;
     } else {

From f24a6cd4d816f70ef4497b9ff48f5a17088d70bb Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Wed, 14 Jan 2026 16:55:42 +0000
Subject: [PATCH 14/25] sparse_ops_gpu.cpp: use different use_var_cols flags
 for small and large embedding functions

---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 39 ++++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 9a3aa96bca..c287c1832b 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -287,7 +287,14 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   const int cols_per_warp = get_group_index_select_cols_per_warp();
   const int unroll_factor = get_group_index_select_unroll_factor();
   int64_t warp_offset = 0;
-  bool use_var_cols = false;
+  bool use_var_cols_small = false;
+  bool use_var_cols_large = false;
+
+  bool first_small_table = true;
+  bool first_large_table = true;
+
+  int prev_num_cols_small;
+  int prev_num_cols_large;
 
   // Allocate memory for output_group
   std::vector output_group;
@@ -341,11 +348,6 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
 #endif // USE_ROCM
 
-    // TODO: maintain [use_var_cols] separately for small emb dims
-    if (num_cols != num_cols_) {
-      use_var_cols = true;
-    }
-
     // Create output pointers
     auto input_shape = input.sizes().vec();
     input_shape[0] = num_output_rows_;
@@ -361,6 +363,12 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     if (num_cols_ < cols_per_warp) {
       // Optimization for Small Embedding: Pack multiple rows per warp
+
+      if(!first_small_table && num_cols_ != prev_num_cols_small) {
+        use_var_cols_small = true;
+      }
+      first_small_table = false;
+      prev_num_cols_small = num_cols_;
       small.input_ptrs[small.count] = reinterpret_cast(input_contigs[i]->data_ptr());
       small.output_ptrs[small.count] = reinterpret_cast(output.data_ptr());
       small.indices_ptrs[small.count] = reinterpret_cast(index_contigs[i]->data_ptr());
@@ -370,6 +378,13 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       small.count++;
     } else {
       // Standard Embedding: One or more warps per row
+
+      if(!first_large_table && num_cols_ != prev_num_cols_large) {
+        use_var_cols_large = true;
+      }
+      first_large_table = false;
+      prev_num_cols_large = num_cols_;
+
       large.input_ptrs[large.count] = reinterpret_cast(input_contigs[i]->data_ptr());
       large.output_ptrs[large.count] = reinterpret_cast(output.data_ptr());
       large.indices_ptrs[large.count] = reinterpret_cast(index_contigs[i]->data_ptr());
@@ -423,7 +438,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   int64_t saved_data_small[] = {
       static_cast(small.count),
-      use_var_cols,
+      use_var_cols_small,
       reinterpret_cast(small.warp_offsets_group),
       reinterpret_cast(small.num_cols_group),
       small.total_warps,
@@ -436,7 +451,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   int64_t saved_data_large[] = {
       static_cast(large.count),
-      use_var_cols,
+      use_var_cols_large,
       reinterpret_cast(large.warp_offsets_group),
       reinterpret_cast(large.num_cols_group),
       large.total_warps,
@@ -461,7 +476,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
         /*total_num_warps=*/small.total_warps,
         small.count,
         /*use_index_select=*/true,
-        use_var_cols);
+        use_var_cols_small);
   }
 
   if(large.count > 0) {
@@ -478,7 +493,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
         /*total_num_warps=*/large.total_warps,
         large.count,
         /*use_index_select=*/true,
-        use_var_cols);
+        use_var_cols_large);
   }
 
   output_group.push_back(args_tensor_small);
@@ -523,7 +538,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   TORCH_CHECK(saved_data_small.is_contiguous());
   int64_t* saved_data_small_ptr = saved_data_small.data_ptr();
   auto count_small = saved_data_small_ptr[0];
-  const bool use_var_cols = saved_data_small_ptr[1];
+  const bool use_var_cols_small = saved_data_small_ptr[1];
   int64_t* warp_offsets_group_small = reinterpret_cast(saved_data_small_ptr[2]);
   int32_t* num_cols_group_small = reinterpret_cast(saved_data_small_ptr[3]);
   int64_t total_num_warps_small = saved_data_small_ptr[4];
@@ -690,7 +705,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         total_num_warps_small,
         group_size,
         /*use_index_select=*/false,
-        use_var_cols);
+        use_var_cols_small);
   }
 
   if(count_large > 0) {

From 39afe2802d797763f49e11c354ac48b2423c6f77 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 15 Jan 2026 14:30:50 +0000
Subject: [PATCH 15/25] sparse_ops_gpu.cpp: corrects computation of num_cols_
 for multi-D inputs in backward function

---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index c287c1832b..3cb757daf1 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -621,8 +621,11 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     grad_input_numels.push_back(grad_input_numel);
     group_grad_input_numel += grad_input_numel;
 
+    auto grad_reshaped = grad.reshape({grad.size(0), -1});
+    auto num_cols_ = grad_reshaped.size(1);
+
     // Put all grad output/input pointers in an array
-    if(grad.size(1) < cols_per_warp) {
+    if(num_cols_ < cols_per_warp) {
       grad_output_ptrs_small[idx_small] = reinterpret_cast(grad_output_contigs[i]->data_ptr());
       idx_small++;
@@ -655,7 +658,9 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     output_group[i] = output_group[i].reshape(grad_input_shape);
     TORCH_CHECK(output_group[i].is_contiguous());
 
-    if(grad_output_group[i].size(1) < cols_per_warp) {
+    auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    auto num_cols_ = output_group_reshaped.size(1);
+    if(num_cols_ < cols_per_warp) {
       grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr());
       idx_small++;
     } else {
@@ -678,7 +683,9 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
 
-    if(grad_output_group[i].size(1) < cols_per_warp) {
+    auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    auto num_cols_ = grad_output_reshaped.size(1);
+    if(num_cols_ < cols_per_warp) {
       indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr());
       idx_small++;
     } else {
@@ -703,7 +710,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         fwd_input.device().index(),
         num_input_rows,
         total_num_warps_small,
-        group_size,
+        count_small,
         /*use_index_select=*/false,
         use_var_cols_small);
   }
@@ -720,7 +727,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         fwd_input.device().index(),
         num_input_rows,
         total_num_warps_large,
-        group_size,
+        count_large,
         /*use_index_select=*/false,
         use_var_cols_large);
   }

From 10b692db62e522e32db32bec9544391720917a83 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 15 Jan 2026 15:15:33 +0000
Subject: [PATCH 16/25] Revert "sparse_ops_gpu.cpp: corrects computation of
 num_cols_ for multi-D inputs in backward function"

This reverts commit 39afe2802d797763f49e11c354ac48b2423c6f77.
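
(Context for the computation being reverted here, and re-applied later in
this series: for inputs with more than two dimensions, the per-table column
count has to be measured after flattening all trailing dimensions, since the
kernels treat every table as 2D. A minimal sketch of that computation using
ATen follows; the helper names are illustrative only, not part of the patch.)

    #include <ATen/ATen.h>

    // Treat a [rows, d1, d2, ...] tensor as [rows, d1*d2*...] when
    // measuring its row width.
    inline int64_t effective_num_cols(const at::Tensor& grad) {
      return grad.reshape({grad.size(0), -1}).size(1);
    }

    // Tables whose flattened row width is below cols_per_warp take the
    // packed-rows (small-embedding-dim) kernel path; wider tables keep
    // using one or more warps per row.
    inline bool use_small_emb_dim_path(
        const at::Tensor& grad,
        int cols_per_warp) {
      return effective_num_cols(grad) < cols_per_warp;
    }

    // Calling grad.size(1) directly, as the code restored by this revert
    // does, matches effective_num_cols(grad) only for strictly 2D inputs.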
--- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 3cb757daf1..c287c1832b 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -621,11 +621,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( grad_input_numels.push_back(grad_input_numel); group_grad_input_numel += grad_input_numel; - auto grad_reshaped = grad.reshape({grad.size(0), -1}); - auto num_cols_ = grad_reshaped.size(1); - // Put all grad output/input pointers in an array - if(num_cols_ < cols_per_warp) { + if(grad.size(1) < cols_per_warp) { grad_output_ptrs_small[idx_small] = reinterpret_cast(grad_output_contigs[i]->data_ptr()); idx_small++; @@ -658,9 +655,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( output_group[i] = output_group[i].reshape(grad_input_shape); TORCH_CHECK(output_group[i].is_contiguous()); - auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); - auto num_cols_ = output_group_reshaped.size(1); - if(num_cols_ < cols_per_warp) { + if(grad_output_group[i].size(1) < cols_per_warp) { grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr()); idx_small++; } else { @@ -683,9 +678,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const auto& indices = indices_group[i]; index_contigs.push_back(indices.expect_contiguous()); - auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); - auto num_cols_ = grad_output_reshaped.size(1); - if(num_cols_ < cols_per_warp) { + if(grad_output_group[i].size(1) < cols_per_warp) { indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr()); idx_small++; } else { @@ -710,7 +703,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( fwd_input.device().index(), num_input_rows, total_num_warps_small, - count_small, + group_size, /*use_index_select=*/false, use_var_cols_small); } @@ -727,7 +720,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( fwd_input.device().index(), num_input_rows, total_num_warps_large, - count_large, + group_size, /*use_index_select=*/false, use_var_cols_large); } From b9ae864ecb4498be4b07f521c8c97687f2f50829 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 15 Jan 2026 15:25:54 +0000 Subject: [PATCH 17/25] sparse_ops_gpu: corrects group_size arguments for split kernels --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index c287c1832b..148b596965 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -703,7 +703,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( fwd_input.device().index(), num_input_rows, total_num_warps_small, - group_size, + count_small, /*use_index_select=*/false, use_var_cols_small); } @@ -720,7 +720,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( fwd_input.device().index(), num_input_rows, total_num_warps_large, - group_size, + count_large, /*use_index_select=*/false, use_var_cols_large); } From 
c54af4f38f5406ce5a986a84d2d8c3e152a6393d Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 15 Jan 2026 14:30:50 +0000
Subject: [PATCH 18/25] sparse_ops_gpu.cpp: corrects computation of num_cols_
 for multi-D inputs in backward function

---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 148b596965..3cb757daf1 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -621,8 +621,11 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     grad_input_numels.push_back(grad_input_numel);
     group_grad_input_numel += grad_input_numel;
 
+    auto grad_reshaped = grad.reshape({grad.size(0), -1});
+    auto num_cols_ = grad_reshaped.size(1);
+
     // Put all grad output/input pointers in an array
-    if(grad.size(1) < cols_per_warp) {
+    if(num_cols_ < cols_per_warp) {
       grad_output_ptrs_small[idx_small] = reinterpret_cast(grad_output_contigs[i]->data_ptr());
       idx_small++;
@@ -655,7 +658,9 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     output_group[i] = output_group[i].reshape(grad_input_shape);
     TORCH_CHECK(output_group[i].is_contiguous());
 
-    if(grad_output_group[i].size(1) < cols_per_warp) {
+    auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    auto num_cols_ = output_group_reshaped.size(1);
+    if(num_cols_ < cols_per_warp) {
       grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr());
       idx_small++;
     } else {
@@ -678,7 +683,9 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
 
-    if(grad_output_group[i].size(1) < cols_per_warp) {
+    auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    auto num_cols_ = grad_output_reshaped.size(1);
+    if(num_cols_ < cols_per_warp) {
       indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr());
       idx_small++;
     } else {

From 65a3f841c35347984fffe945b31891101f100f70 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 15 Jan 2026 18:22:25 +0000
Subject: [PATCH 19/25] sparse_ops_gpu: use const for temporary variables

---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 3cb757daf1..2e8dc634ad 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -621,8 +621,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     grad_input_numels.push_back(grad_input_numel);
     group_grad_input_numel += grad_input_numel;
 
-    auto grad_reshaped = grad.reshape({grad.size(0), -1});
-    auto num_cols_ = grad_reshaped.size(1);
+    const auto grad_reshaped = grad.reshape({grad.size(0), -1});
+    const auto num_cols_ = grad_reshaped.size(1);
 
     // Put all grad output/input pointers in an array
     if(num_cols_ < cols_per_warp) {
@@ -658,8 +658,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     output_group[i] = output_group[i].reshape(grad_input_shape);
     TORCH_CHECK(output_group[i].is_contiguous());
 
-    auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
-    auto num_cols_ =
output_group_reshaped.size(1); + const auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); + const auto num_cols_ = output_group_reshaped.size(1); if(num_cols_ < cols_per_warp) { grad_input_ptrs_small[idx_small] = reinterpret_cast(output_group[i].data_ptr()); idx_small++; @@ -683,8 +683,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const auto& indices = indices_group[i]; index_contigs.push_back(indices.expect_contiguous()); - auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); - auto num_cols_ = grad_output_reshaped.size(1); + const auto grad_output_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); + const auto num_cols_ = grad_output_reshaped.size(1); if(num_cols_ < cols_per_warp) { indices_ptrs_small[idx_small] = reinterpret_cast(index_contigs[i]->data_ptr()); idx_small++; From ab14c59abca53b7691643c72f73ca95d2dbc4748 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 15 Jan 2026 18:24:36 +0000 Subject: [PATCH 20/25] sparse_group_index: prunes code in small and large group index select kernels --- .../src/sparse_ops/sparse_group_index.cu | 171 +++++------------- 1 file changed, 41 insertions(+), 130 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 04ec7b5d06..97e3b7bab4 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -84,82 +84,44 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same - member_id = warp_id / (warps_per_row * num_work_rows); - member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); -#ifdef USE_ROCM - if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { - // Need to ensure that [member_id] and [member_warp_id] are calculated correctly - // for the small embedding dimension path below - int rows_per_warp = COLS_PER_WARP / num_cols; - auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; - member_id = warp_id / warps_per_member; - member_warp_id = warp_id % warps_per_member; - } -#endif // USE_ROCM + int rows_per_warp = COLS_PER_WARP / num_cols; + auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; + member_id = warp_id / warps_per_member; + member_warp_id = warp_id % warps_per_member; + } -#ifdef USE_ROCM - if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { - // Optimized path for small embedding dimensions - // Each warp processes 'rows_per_warp' rows - int rows_per_warp = COLS_PER_WARP / num_cols; - int64_t start_row = member_warp_id * rows_per_warp; - - // Since we are processing multiple rows within the warp, we need to - // map each lane to a specific row, in addition to the column - int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp - int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; - int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane - - // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] - // and we also need to confirm that we are within num_work_rows - if (local_row < rows_per_warp && current_row < num_work_rows) { - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + 
col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[current_row]; -#pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[current_row * num_cols + i]); - } - } - } - } else { - // Large embedding dimensions use >= 1 warp per row - // which is the default codepath for non-ROCm as well -#endif // USE_ROCM - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = + // Each warp processes 'rows_per_warp' rows + int rows_per_warp = COLS_PER_WARP / num_cols; + int64_t start_row = member_warp_id * rows_per_warp; + + // Since we are processing multiple rows within the warp, we need to + // map each lane to a specific row, in addition to the column + int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp + int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; + int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane + + // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] + // and we also need to confirm that we are within num_work_rows + if (local_row < rows_per_warp && current_row < num_work_rows) { + scalar_t* input = reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = + scalar_t* output = reinterpret_cast(output_ptrs[member_id]) + col_offset; index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + const index_t idx = indices[current_row]; #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { // Compile time conditional if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[row * num_cols + i]); + &output[idx * num_cols + i], input[current_row * num_cols + i]); } } -#ifdef USE_ROCM } -#endif // USE_ROCM } } @@ -212,80 +174,29 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( // All columns are the same member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); -#ifdef USE_ROCM - if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { - // Need to ensure that [member_id] and [member_warp_id] are calculated correctly - // for the small embedding dimension path below - int rows_per_warp = COLS_PER_WARP / num_cols; - auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; - member_id = warp_id / warps_per_member; - member_warp_id = warp_id % warps_per_member; - } -#endif // USE_ROCM } -#ifdef USE_ROCM - if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { - // Optimized path for small embedding dimensions - // Each warp processes 'rows_per_warp' rows - int rows_per_warp = COLS_PER_WARP / num_cols; - int64_t start_row = member_warp_id * rows_per_warp; - - // Since we are processing multiple rows within the warp, we need to - // map each 
lane to a specific row, in addition to the column - int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp - int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; - int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane - - // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] - // and we also need to confirm that we are within num_work_rows - if (local_row < rows_per_warp && current_row < num_work_rows) { - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[current_row]; -#pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[current_row * num_cols + i]); - } - } - } - } else { - // Large embedding dimensions use >= 1 warp per row - // which is the default codepath for non-ROCm as well -#endif // USE_ROCM - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[row * num_cols + i]); - } + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional + if constexpr (USE_INDEX_SELECT) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } else { + gpuAtomicAddNoReturn( + &output[idx * num_cols + i], input[row * num_cols + i]); } -#ifdef USE_ROCM } -#endif // USE_ROCM } } From 598b2def1cc510416649e9dc71437e2969ef0243 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 9 Dec 2025 21:47:23 -0800 Subject: [PATCH 21/25] Revert D87922263 (#5207) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2206 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5207 - Revert the changes made in D87922263 Reviewed By: cthi, atalman, huydhn Differential Revision: D88774663 fbshipit-source-id: ecc0486eb82564ebc31eac503c58a35600816548 --- .../common/include/fbgemm_gpu/quantize/tuning_cache.cuh | 4 ---- .../merge_pooled_embedding_ops_gpu.cpp | 5 ----- 2 files changed, 9 deletions(-) diff --git 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.cuh
index bc5438783e..c86d78fe78 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.cuh
@@ -279,11 +279,7 @@ class TuningCache final {
   constexpr static std::string_view FBGEMM_CACHE_DIR = ".fbgemm";
 
-#if !defined(FBGEMM_FBCODE) && ROCM_VERSION >= 70000
-  using GPUEvent = at::hip::HIPEvent;
-#else
   using GPUEvent = at::cuda::CUDAEvent;
-#endif
   GPUEvent start_ = GPUEvent(cudaEventDefault);
   GPUEvent stop_ = GPUEvent(cudaEventDefault);
 
diff --git a/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
index 66585481ad..e483a2893b 100644
--- a/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
+++ b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
@@ -27,13 +27,8 @@
 
 using Tensor = at::Tensor;
 
-#if !defined(FBGEMM_FBCODE) && ROCM_VERSION >= 70000
-using GPUEvent = at::hip::HIPEvent;
-#define getCurrentGPUStream at::hip::getCurrentHIPStream
-#else
 using GPUEvent = at::cuda::CUDAEvent;
 #define getCurrentGPUStream at::cuda::getCurrentCUDAStream
-#endif
 
 namespace {
 struct DirectConnectedPeer {

From 6a3fef150791ab5e5fa2117010d32a2bcec8e1d3 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Fri, 30 Jan 2026 16:02:46 +0000
Subject: [PATCH 22/25] combines split group_index_select_or_add_2d_kernel
 kernels with a compile-time split

---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h   |  18 +-
 .../src/sparse_ops/sparse_group_index.cu     | 270 ++++++------------
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp |  16 +-
 3 files changed, 99 insertions(+), 205 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 82e09ccc28..aaf189721c 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1086,22 +1086,8 @@ void group_index_select_or_add_cuda(
     const int64_t total_num_warps,
     const int group_size,
     const bool use_index_select,
-    const bool use_var_cols);
-
-void group_index_select_or_add_cuda_smallEmbD(
-    const int64_t* input_ptrs,
-    const int64_t* output_ptrs,
-    const int64_t* indices_ptrs,
-    const int64_t* warp_offsets_group,
-    const int32_t* num_cols_group,
-    const c10::ScalarType& input_scalar_type,
-    const c10::ScalarType& indices_scalar_type,
-    const c10::DeviceIndex& device,
-    const int num_work_rows,
-    const int64_t total_num_warps,
-    const int group_size,
-    const bool use_index_select,
-    const bool use_var_cols);
+    const bool use_var_cols,
+    const bool use_small_emb_dim);
 
 int get_group_index_select_cols_per_warp();
 
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 97e3b7bab4..3860c73eeb 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -42,11 +42,12 @@ template <
     typename scalar_t,
     bool USE_INDEX_SELECT,
     bool USE_VAR_COLS,
+    bool USE_SMALL_EMB_DIM,
     int UNROLL_FACTOR,
     int COLS_PER_WARP,
     int LOG_COLS_PER_WARP>
 __global__
-__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small(
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const
int64_t* input_ptrs, const int64_t* output_ptrs, const int64_t* indices_ptrs, @@ -84,122 +85,79 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel_small( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same - int rows_per_warp = COLS_PER_WARP / num_cols; - auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; - member_id = warp_id / warps_per_member; - member_warp_id = warp_id % warps_per_member; - + if constexpr (USE_SMALL_EMB_DIM) { + // Small embedding: pack multiple rows per warp + int rows_per_warp = COLS_PER_WARP / num_cols; + auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp; + member_id = warp_id / warps_per_member; + member_warp_id = warp_id % warps_per_member; + } else { + // Large embedding: one or more warps per row + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } } - // Each warp processes 'rows_per_warp' rows - int rows_per_warp = COLS_PER_WARP / num_cols; - int64_t start_row = member_warp_id * rows_per_warp; - - // Since we are processing multiple rows within the warp, we need to - // map each lane to a specific row, in addition to the column - int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp - int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; - int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane - - // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] - // and we also need to confirm that we are within num_work_rows - if (local_row < rows_per_warp && current_row < num_work_rows) { - scalar_t* input = + if constexpr (USE_SMALL_EMB_DIM) { + // Small embedding dimension: pack multiple rows per warp + // Each warp processes 'rows_per_warp' rows + int rows_per_warp = COLS_PER_WARP / num_cols; + int64_t start_row = member_warp_id * rows_per_warp; + + // Since we are processing multiple rows within the warp, we need to + // map each lane to a specific row, in addition to the column + int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp + int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; + int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane + + // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] + // and we also need to confirm that we are within num_work_rows + if (local_row < rows_per_warp && current_row < num_work_rows) { + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[current_row]; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional + if constexpr (USE_INDEX_SELECT) { + output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } else { + gpuAtomicAddNoReturn( + &output[idx * num_cols + i], input[current_row * num_cols + i]); + } + } + } + } else { + // Large embedding dimension: one or more warps per row + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * 
UNROLL_FACTOR); + scalar_t* input = reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = + scalar_t* output = reinterpret_cast(output_ptrs[member_id]) + col_offset; index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[current_row]; + const index_t idx = indices[row]; #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { // Compile time conditional if constexpr (USE_INDEX_SELECT) { - output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[current_row * num_cols + i]); + &output[idx * num_cols + i], input[row * num_cols + i]); } } } } } -template < - typename index_t, - typename scalar_t, - bool USE_INDEX_SELECT, - bool USE_VAR_COLS, - int UNROLL_FACTOR, - int COLS_PER_WARP, - int LOG_COLS_PER_WARP> -__global__ -__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( - const int64_t* input_ptrs, - const int64_t* output_ptrs, - const int64_t* indices_ptrs, - const int64_t* warp_offsets_group, - const int32_t* num_cols_group, - const int64_t num_work_rows, // number of rows to work on per member - const int64_t group_size) { - const auto total_num_warps = warp_offsets_group[group_size]; - int32_t num_cols = 0; - int32_t warps_per_row = 0; - - if constexpr (!USE_VAR_COLS) { - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - } - - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; - warp_id < total_num_warps; - warp_id += gridDim.x * blockDim.y) { - int32_t member_id = 0; - int32_t member_warp_id = 0; - if constexpr (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; - if (threadIdx.x == 0) { - binary_search_range( - &member_ids[threadIdx.y], - warp_offsets_group + 1, - warp_id, - group_size); - } - syncwarp(); - member_id = member_ids[threadIdx.y]; - num_cols = num_cols_group[member_id]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - member_warp_id = warp_id - warp_offsets_group[member_id]; - } else { - // All columns are the same - member_id = warp_id / (warps_per_row * num_work_rows); - member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); - } - - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; -#pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[row * num_cols + i]); - } - } - } -} - DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t* input_ptrs, const int64_t* output_ptrs, @@ -213,7 +171,8 @@ DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t total_num_warps, const int group_size, const bool use_index_select, - const bool use_var_cols) { + const bool use_var_cols, + const bool use_small_emb_dim) { if (group_size == 0) { return; } @@ -229,13 +188,15 @@ DLL_PUBLIC void 
group_index_select_or_add_cuda( max_grid_size); dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); -#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ +// Kernel launch macro with USE_SMALL_EMB_DIM template parameter +#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS, USE_SMALL_EMB_DIM) \ FBGEMM_LAUNCH_KERNEL( \ (group_index_select_or_add_2d_kernel< \ index_t, \ scalar_t, \ USE_INDEX_SELECT, \ USE_VAR_COLS, \ + USE_SMALL_EMB_DIM, \ GROUP_INDEX_SELECT_UNROLL_FACTOR, \ GROUP_INDEX_SELECT_COLS_PER_WARP, \ GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ @@ -255,17 +216,35 @@ DLL_PUBLIC void group_index_select_or_add_cuda( indices_scalar_type, "group_index_select_2d_wrapper_1", [&] { FBGEMM_DISPATCH_FLOATING_TYPES( input_scalar_type, "group_index_select_2d_wrapper_2", [&] { - if (use_index_select) { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true); + if (use_small_emb_dim) { + // Small embedding dimension: pack multiple rows per warp + if (use_index_select) { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, true); + } } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false); + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, true); + } } } else { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true); + // Large embedding dimension: one or more warps per row + if (use_index_select) { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, false); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, false); + } } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false); + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, false); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, false); + } } } }); @@ -274,79 +253,4 @@ DLL_PUBLIC void group_index_select_or_add_cuda( #undef INVOKE_GROUP_INDEX_SELECT_OR_ADD } -DLL_PUBLIC void group_index_select_or_add_cuda_smallEmbD( - const int64_t* input_ptrs, - const int64_t* output_ptrs, - const int64_t* indices_ptrs, - const int64_t* warp_offsets_group, - const int32_t* num_cols_group, - const c10::ScalarType& input_scalar_type, - const c10::ScalarType& indices_scalar_type, - const c10::DeviceIndex& device, - const int num_work_rows, - const int64_t total_num_warps, - const int group_size, - const bool use_index_select, - const bool use_var_cols) { - if (group_size == 0) { - return; - } - - at::cuda::OptionalCUDAGuard device_guard(device); - - // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; - uint32_t max_grid_size = - at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; - uint32_t grid_size = std::min( - cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), - max_grid_size); - dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); - - // Launcher Macro for Small Kernel -#define INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(USE_INDEX_SELECT, USE_VAR_COLS) \ - FBGEMM_LAUNCH_KERNEL( \ - (group_index_select_or_add_2d_kernel_small< \ - index_t, \ - scalar_t, \ - USE_INDEX_SELECT, \ - USE_VAR_COLS, \ - GROUP_INDEX_SELECT_UNROLL_FACTOR, \ - GROUP_INDEX_SELECT_COLS_PER_WARP, \ - GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ - grid_size, \ - block_size, \ - 0, \ - at::cuda::getCurrentCUDAStream(), \ - input_ptrs, \ - output_ptrs, \ - indices_ptrs, \ - warp_offsets_group, \ 
- num_cols_group, \ - num_work_rows, \ - group_size) - - AT_DISPATCH_INDEX_TYPES( - indices_scalar_type, "group_index_select_2d_small_wrapper_1", [&] { - FBGEMM_DISPATCH_FLOATING_TYPES( - input_scalar_type, "group_index_select_2d_small_wrapper_2", [&] { - if (use_index_select) { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(true, true); - } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(true, false); - } - } else { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(false, true); - } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL(false, false); - } - } - }); - }); - -#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD_SMALL -} - } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 2e8dc634ad..990dba57ec 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -463,7 +463,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( memcpy(saved_data_t_large.data_ptr(), saved_data_large, sizeof(saved_data_large)); if(small.count > 0) { - group_index_select_or_add_cuda_smallEmbD( + group_index_select_or_add_cuda( small.input_ptrs, small.output_ptrs, small.indices_ptrs, @@ -476,7 +476,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( /*total_num_warps=*/small.total_warps, small.count, /*use_index_select=*/true, - use_var_cols_small); + use_var_cols_small, + /*use_small_emb_dim=*/true); } if(large.count > 0) { @@ -493,7 +494,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( /*total_num_warps=*/large.total_warps, large.count, /*use_index_select=*/true, - use_var_cols_large); + use_var_cols_large, + /*use_small_emb_dim=*/false); } output_group.push_back(args_tensor_small); @@ -699,7 +701,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true); if(count_small > 0) { - group_index_select_or_add_cuda_smallEmbD( + group_index_select_or_add_cuda( args_tensor_small.data_ptr(), args_tensor_small.data_ptr() + count_small, args_tensor_small.data_ptr() + 2 * count_small, @@ -712,7 +714,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( total_num_warps_small, count_small, /*use_index_select=*/false, - use_var_cols_small); + use_var_cols_small, + /*use_small_emb_dim=*/true); } if(count_large > 0) { @@ -729,7 +732,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( total_num_warps_large, count_large, /*use_index_select=*/false, - use_var_cols_large); + use_var_cols_large, + /*use_small_emb_dim=*/false); } return outputs; From adb728e0b7e0b3ee6fa06c7a2f8a670fa30a143f Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 2 Feb 2026 14:18:57 +0000 Subject: [PATCH 23/25] fixes merge issues --- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 17 ++++------------- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 6 ++---- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 76619e7d8a..62835a537e 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3508,14 +3508,10 @@ torch::autograd::variable_list group_index_select_dim0( at::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") 
.typed(); -<<<<<<< HEAD - auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() >= group_size + 2); -======= + auto res = forward_op.call( - all_indices_input_tensor, static_cast(group_size)); - TORCH_CHECK(res.size() == group_size + 2); ->>>>>>> origin/main_12162025_upstream + all_indices_input_tensor, static_cast(group_size)); + TORCH_CHECK(res.size() >= group_size + 2); // only return the outputs (the first group_size elements) res.resize(group_size); return res; @@ -3671,13 +3667,8 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::backward( // remove redundant grads grad_output_group.resize(group_size); -<<<<<<< HEAD - auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() >= group_size + 3); -======= const auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); ->>>>>>> origin/main_12162025_upstream + TORCH_CHECK(saved_tensors.size() >= group_size + 3); std::vector output_shape_group; int i = 0; while (true) { diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index f784cb5484..3ff429c786 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -569,7 +569,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( // Retrieve saved data TORCH_CHECK(saved_data_small.device() == at::kCPU, "Tensor saved_data_small must be on CPU."); TORCH_CHECK(saved_data_small.is_contiguous(), "Tensor saved_data_small must be contiguous."); - int64_t* saved_data_small_ptr = saved_data_small.const_data_ptr(); + const int64_t* saved_data_small_ptr = saved_data_small.const_data_ptr(); auto count_small = saved_data_small_ptr[0]; const bool use_var_cols_small = saved_data_small_ptr[1]; int64_t* warp_offsets_group_small = reinterpret_cast(saved_data_small_ptr[2]); @@ -578,7 +578,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( TORCH_CHECK(saved_data_large.device() == at::kCPU, "Tensor saved_data_large must be on CPU."); TORCH_CHECK(saved_data_large.is_contiguous(), "Tensor saved_data_large must be contiguous."); - int64_t* saved_data_large_ptr = saved_data_large.const_data_ptr(); + const int64_t* saved_data_large_ptr = saved_data_large.const_data_ptr(); auto count_large = saved_data_large_ptr[0]; const bool use_var_cols_large = saved_data_large_ptr[1]; int64_t* warp_offsets_group_large = reinterpret_cast(saved_data_large_ptr[2]); @@ -711,8 +711,6 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( " of ", group_size, " must be contiguous."); - grad_input_ptrs[i] = - reinterpret_cast(output_group[i].const_data_ptr()); const auto output_group_reshaped = grad_output_group[i].reshape({grad_output_group[i].size(0), -1}); const auto num_cols_ = output_group_reshaped.size(1); From 9310b72f1bff94dd833f59dafadf9aeb8ae959b6 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 2 Feb 2026 15:23:45 +0000 Subject: [PATCH 24/25] guard changes with USE_ROCM --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 38 ++- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 7 +- .../src/sparse_ops/sparse_group_index.cu | 74 +++++- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 18 +- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 238 +++++++++++++++++- 5 files changed, 340 insertions(+), 35 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index e726689084..b48a159963 100644 
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -982,23 +982,37 @@ def group_index_select_dim0_gpu_impl_abstract( # divide by 2 since sizeof(int64_t) / sizeof(int32_t) = 2 args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2)) - ret.append( - # sizeof(int64_t) = 8, torch.uint8 = at::kByte - input_group[0].new_empty( - args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True + # Runtime check for ROCm vs CUDA + is_rocm = torch.version.hip is not None + + if is_rocm: + # ROCm: Allocate dual args_tensors and saved_data tensors + ret.append( + # sizeof(int64_t) = 8, torch.uint8 = at::kByte + input_group[0].new_empty( + args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True + ) ) - ) - ret.append( - # sizeof(int64_t) = 8, torch.uint8 = at::kByte - input_group[0].new_empty( - args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True + ret.append( + # sizeof(int64_t) = 8, torch.uint8 = at::kByte + input_group[0].new_empty( + args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True + ) ) - ) - ret.append(torch.zeros(5, dtype=torch.int64, device="cpu")) - ret.append(torch.zeros(5, dtype=torch.int64, device="cpu")) + ret.append(torch.zeros(5, dtype=torch.int64, device="cpu")) + ret.append(torch.zeros(5, dtype=torch.int64, device="cpu")) + else: + # CUDA: Allocate single args_tensor and saved_data tensor + ret.append( + # sizeof(int64_t) = 8, torch.uint8 = at::kByte + input_group[0].new_empty( + args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True + ) + ) + ret.append(torch.zeros(5, dtype=torch.int64, device="cpu")) return ret diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 708312610c..09f5caa870 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -1087,8 +1087,11 @@ void group_index_select_or_add_cuda( const int64_t total_num_warps, const int group_size, const bool use_index_select, - const bool use_var_cols, - const bool use_small_emb_dim); + const bool use_var_cols +#ifdef USE_ROCM + ,const bool use_small_emb_dim +#endif +); int get_group_index_select_cols_per_warp(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index aa40665d45..4b3b1d0c3f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -37,6 +37,7 @@ int get_group_index_select_unroll_factor() { return GROUP_INDEX_SELECT_UNROLL_FACTOR; } +#ifdef USE_ROCM template < typename index_t, typename scalar_t, @@ -46,6 +47,16 @@ template < int UNROLL_FACTOR, int COLS_PER_WARP, int LOG_COLS_PER_WARP> +#else +template < + typename index_t, + typename scalar_t, + bool USE_INDEX_SELECT, + bool USE_VAR_COLS, + int UNROLL_FACTOR, + int COLS_PER_WARP, + int LOG_COLS_PER_WARP> +#endif __global__ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t* input_ptrs, @@ -85,6 +96,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same +#ifdef USE_ROCM if constexpr (USE_SMALL_EMB_DIM) { // Small embedding: pack multiple rows per warp const auto rows_per_warp = COLS_PER_WARP / num_cols; @@ -93,18 +105,22 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_id = warp_id / warps_per_member; member_warp_id = warp_id % warps_per_member; } else { +#endif // Large embedding: one or more 
warps per row member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); +#ifdef USE_ROCM } +#endif } +#ifdef USE_ROCM if constexpr (USE_SMALL_EMB_DIM) { // Small embedding dimension: pack multiple rows per warp // Each warp processes 'rows_per_warp' rows const auto rows_per_warp = COLS_PER_WARP / num_cols; int64_t start_row = member_warp_id * rows_per_warp; - + // Since we are processing multiple rows within the warp, we need to // map each lane to a specific row, in addition to the column int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp @@ -114,9 +130,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] // and we also need to confirm that we are within num_work_rows if (local_row < rows_per_warp && current_row < num_work_rows) { - scalar_t* input = + scalar_t* input = reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = + scalar_t* output = reinterpret_cast(output_ptrs[member_id]) + col_offset; index_t* indices = reinterpret_cast(indices_ptrs[member_id]); @@ -133,6 +149,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } } else { +#endif // Large embedding dimension: one or more warps per row const auto row = member_warp_id / warps_per_row; const auto col_offset = @@ -155,7 +172,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( &output[idx * num_cols + i], input[row * num_cols + i]); } } +#ifdef USE_ROCM } +#endif } } @@ -172,8 +191,11 @@ DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t total_num_warps, const int group_size, const bool use_index_select, - const bool use_var_cols, - const bool use_small_emb_dim) { + const bool use_var_cols +#ifdef USE_ROCM + ,const bool use_small_emb_dim +#endif +) { if (group_size == 0) { return; } @@ -189,6 +211,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda( max_grid_size); dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); +#ifdef USE_ROCM // Kernel launch macro with USE_SMALL_EMB_DIM template parameter #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS, USE_SMALL_EMB_DIM) \ FBGEMM_LAUNCH_KERNEL( \ @@ -212,11 +235,36 @@ DLL_PUBLIC void group_index_select_or_add_cuda( num_cols_group, \ num_work_rows, \ group_size) +#else +// Kernel launch macro for CUDA (no USE_SMALL_EMB_DIM) +#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ + FBGEMM_LAUNCH_KERNEL( \ + (group_index_select_or_add_2d_kernel< \ + index_t, \ + scalar_t, \ + USE_INDEX_SELECT, \ + USE_VAR_COLS, \ + GROUP_INDEX_SELECT_UNROLL_FACTOR, \ + GROUP_INDEX_SELECT_COLS_PER_WARP, \ + GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ + grid_size, \ + block_size, \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + input_ptrs, \ + output_ptrs, \ + indices_ptrs, \ + warp_offsets_group, \ + num_cols_group, \ + num_work_rows, \ + group_size) +#endif AT_DISPATCH_INDEX_TYPES( indices_scalar_type, "group_index_select_2d_wrapper_1", [&] { FBGEMM_DISPATCH_FLOATING_TYPES( input_scalar_type, "group_index_select_2d_wrapper_2", [&] { +#ifdef USE_ROCM if (use_small_emb_dim) { // Small embedding dimension: pack multiple rows per warp if (use_index_select) { @@ -248,6 +296,22 @@ DLL_PUBLIC void group_index_select_or_add_cuda( } } } +#else + // CUDA: Standard path only (no small embedding optimization) + if 
(use_index_select) { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false); + } + } else { + if (use_var_cols) { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true); + } else { + INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false); + } + } +#endif }); }); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 62835a537e..ee8af2499f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3511,7 +3511,11 @@ torch::autograd::variable_list group_index_select_dim0( auto res = forward_op.call( all_indices_input_tensor, static_cast(group_size)); - TORCH_CHECK(res.size() >= group_size + 2); +#ifdef USE_ROCM + TORCH_CHECK(res.size() >= group_size + 4); // ROCm: +4 tensors (2 args, 2 saved_data) +#else + TORCH_CHECK(res.size() >= group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data) +#endif // only return the outputs (the first group_size elements) res.resize(group_size); return res; @@ -3622,7 +3626,11 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) >= group_size + 2); +#ifdef USE_ROCM + TORCH_CHECK(static_cast(result.size()) >= group_size + 4); // ROCm: +4 tensors +#else + TORCH_CHECK(static_cast(result.size()) >= group_size + 2); // CUDA: +2 tensors +#endif ctx->saved_data["group_size"] = group_size; auto [input_group, indices_group] = @@ -3668,7 +3676,11 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::backward( grad_output_group.resize(group_size); const auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() >= group_size + 3); +#ifdef USE_ROCM + TORCH_CHECK(saved_tensors.size() >= group_size + 5); // ROCm: indices + 4 tensors + fwd_input +#else + TORCH_CHECK(saved_tensors.size() >= group_size + 3); // CUDA: indices + 2 tensors + fwd_input +#endif std::vector output_shape_group; int i = 0; while (true) { diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 3ff429c786..005b19fc0c 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -238,7 +238,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // Total number of int64_t elements required args_ptrs_offsets[NUM_ARGS] = offset; - // Allocate memory for GroupIndexSelectArgs +#ifdef USE_ROCM + // Allocate memory for GroupIndexSelectArgs (ROCm: split small/large) at::Tensor args_tensor_small = at::empty( {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, at::TensorOptions().dtype(at::kByte).pinned_memory(true)); @@ -282,6 +283,33 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( &large.num_cols_group, reinterpret_cast(args_tensor_large.mutable_data_ptr()), args_ptrs_offsets); +#else + // Allocate memory for GroupIndexSelectArgs (CUDA: unified) + at::Tensor args_tensor = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); + + // Ensure that args_tensor is contiguous + TORCH_CHECK( + args_tensor.is_contiguous(), "Tensor args_tensor must be contiguous."); + + // Initialize raw pointers to point to Tensor args_tensor + int64_t* input_ptrs = nullptr; + int64_t* 
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 3ff429c786..005b19fc0c 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -238,7 +238,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   // Total number of int64_t elements required
   args_ptrs_offsets[NUM_ARGS] = offset;

-  // Allocate memory for GroupIndexSelectArgs
+#ifdef USE_ROCM
+  // Allocate memory for GroupIndexSelectArgs (ROCm: split small/large)
   at::Tensor args_tensor_small = at::empty(
       {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
       at::TensorOptions().dtype(at::kByte).pinned_memory(true));
@@ -282,6 +283,33 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       &large.num_cols_group,
       reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
       args_ptrs_offsets);
+#else
+  // Allocate memory for GroupIndexSelectArgs (CUDA: unified)
+  at::Tensor args_tensor = at::empty(
+      {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
+      at::TensorOptions().dtype(at::kByte).pinned_memory(true));
+
+  // Ensure that args_tensor is contiguous
+  TORCH_CHECK(
+      args_tensor.is_contiguous(), "Tensor args_tensor must be contiguous.");
+
+  // Initialize raw pointers to point to Tensor args_tensor
+  int64_t* input_ptrs = nullptr;
+  int64_t* output_ptrs = nullptr;
+  int64_t* indices_ptrs = nullptr;
+  int64_t* warp_offsets_group = nullptr;
+  int32_t* num_cols_group = nullptr;
+
+  // Offset host pointers
+  offset_args(
+      &input_ptrs,
+      &output_ptrs,
+      &indices_ptrs,
+      &warp_offsets_group,
+      &num_cols_group,
+      reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
+      args_ptrs_offsets);
+#endif

   auto& first_input = input_group[0];
   auto& first_indices = indices_group[0];
@@ -296,15 +324,20 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   [[maybe_unused]] const int unroll_factor =
       get_group_index_select_unroll_factor();

+#ifdef USE_ROCM
   int64_t warp_offset = 0;
   bool use_var_cols_small = false;
   bool use_var_cols_large = false;
-
+
   bool first_small_table = true;
   bool first_large_table = true;
   int prev_num_cols_small;
   int prev_num_cols_large;
+#else
+  int64_t warp_offset = 0;
+  bool use_var_cols = false;
+#endif

   // Allocate memory for output_group
   std::vector<Tensor> output_group;
@@ -368,9 +401,9 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       warps_needed = warps_per_row * num_output_rows_;
     }
 #else
-    // Standard: One or more warps per row
+    // CUDA: Standard path only (no small embedding optimization)
     auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
-#endif // USE_ROCM
+#endif

     // Create output pointers
     auto input_shape = input.sizes().vec();
@@ -385,6 +418,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     input_contigs.push_back(input.expect_contiguous());
     index_contigs.push_back(indices.expect_contiguous());

+#ifdef USE_ROCM
     if (num_cols_ < cols_per_warp) {
       // Optimization for Small Embedding: Pack multiple rows per warp
@@ -417,8 +451,25 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       large.total_warps += warps_needed;
       large.count++;
     }
+#else
+    // CUDA: Store args in unified arrays
+    if (num_cols != num_cols_) {
+      use_var_cols = true;
+    }
+
+    input_ptrs[i] =
+        reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
+    output_ptrs[i] = reinterpret_cast<int64_t>(output.mutable_data_ptr());
+    indices_ptrs[i] =
+        reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+    warp_offsets_group[i] = warp_offset;
+    num_cols_group[i] = num_cols_;
+
+    warp_offset += warps_per_row * num_output_rows;
+#endif
   }

+#ifdef USE_ROCM
   // Store the last offset
   if (small.count > 0) {
     small.warp_offsets_group[small.count] = small.total_warps;
@@ -432,7 +483,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     args_tensor_small = args_tensor_small.to(
         first_input.device(),
         /*non_blocking=*/true);
-
+
     // Offset raw ptrs in GPU memory
     offset_args(
         &small.input_ptrs,
@@ -448,7 +499,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     args_tensor_large = args_tensor_large.to(
         first_input.device(),
         /*non_blocking=*/true);
-
+
     // Offset raw ptrs in GPU memory
     offset_args(
         &large.input_ptrs,
@@ -458,8 +509,28 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
         &large.num_cols_group,
         reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
         args_ptrs_offsets);
-  }
+  }
+#else
+  // Store the last offset
+  warp_offsets_group[group_size] = warp_offset;
+
+  // Transfer args tensor to GPU
+  args_tensor = args_tensor.to(
+      first_input.device(),
+      /*non_blocking=*/true);
+
+  // Offset raw ptrs in GPU memory
+  offset_args(
+      &input_ptrs,
+      &output_ptrs,
+      &indices_ptrs,
+      &warp_offsets_group,
+      &num_cols_group,
+      reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
+      args_ptrs_offsets);
+#endif
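
Both backends follow the same staging pattern here: pack every device pointer into one pinned host tensor, then move it with a single non-blocking copy so the launch is not serialized behind many small transfers. A minimal sketch of the pattern, assuming the same ATen calls the file already uses (the helper name is illustrative, not part of the source):

    // Minimal sketch of the pinned-host staging pattern: build the pointer
    // table in page-locked CPU memory, then copy it to the GPU with a single
    // non-blocking transfer on the current stream.
    #include <ATen/ATen.h>
    #include <vector>

    at::Tensor stage_pointer_table(const std::vector<const void*>& ptrs,
                                   const at::Device& device) {
      // Pinned memory keeps the async copy truly asynchronous.
      at::Tensor table = at::empty(
          {static_cast<int64_t>(ptrs.size())},
          at::TensorOptions().dtype(at::kLong).pinned_memory(true));
      int64_t* data = table.mutable_data_ptr<int64_t>();
      for (size_t i = 0; i < ptrs.size(); ++i) {
        data[i] = reinterpret_cast<int64_t>(ptrs[i]); // pointers as int64
      }
      // The pinned source must stay alive until the copy completes, which is
      // why the forward impl returns its args tensor to the autograd context.
      return table.to(device, /*non_blocking=*/true);
    }
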

+#ifdef USE_ROCM
   int64_t saved_data_small[] = {
       static_cast<int64_t>(small.count),
       use_var_cols_small,
@@ -530,28 +601,66 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   output_group.push_back(saved_data_t_large);

   // return format:
-  // (group_size outputs, 2 args_tensor2, 2 saved_data)
+  // (group_size outputs, 2 args_tensors, 2 saved_data)
   return output_group;
+#else
+  int64_t saved_data[] = {
+      static_cast<int64_t>(group_size),
+      use_var_cols,
+      reinterpret_cast<int64_t>(warp_offsets_group),
+      reinterpret_cast<int64_t>(num_cols_group),
+      warp_offset,
+  };
+  auto saved_data_t = at::empty(
+      {sizeof(saved_data) / sizeof(int64_t)},
+      at::TensorOptions().dtype(at::kLong));
+  TORCH_CHECK(
+      saved_data_t.is_contiguous(), "Tensor saved_data_t must be contiguous.");
+  memcpy(
+      saved_data_t.mutable_data_ptr(), saved_data, sizeof(saved_data));
+
+  group_index_select_or_add_cuda(
+      input_ptrs,
+      output_ptrs,
+      indices_ptrs,
+      warp_offsets_group,
+      num_cols_group,
+      first_input.scalar_type(),
+      first_indices.scalar_type(),
+      first_input.device().index(),
+      num_output_rows,
+      /*total_num_warps=*/warp_offset,
+      group_size,
+      /*use_index_select=*/true,
+      use_var_cols);
+
+  output_group.push_back(args_tensor);
+  output_group.push_back(saved_data_t);
+
+  // return format:
+  // (group_size outputs, 1 args_tensor, 1 saved_data)
+  return output_group;
+#endif
 }
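
saved_data is a five-slot CPU int64 tensor that round-trips the launch parameters from forward to backward; slots 2 and 3 hold raw device pointers cast to int64_t. A sketch of the layout under that assumption (the struct and helper are illustrative, not part of the source):

    // Sketch of the saved_data slot layout used by the forward/backward
    // pair: five int64 values, two of which are device pointers stored as
    // integers.
    #include <cstdint>

    struct SavedData {
      int64_t group_size;          // slot 0
      int64_t use_var_cols;        // slot 1 (bool widened to int64)
      const int64_t* warp_offsets; // slot 2, device pointer
      const int32_t* num_cols;     // slot 3, device pointer
      int64_t total_num_warps;     // slot 4
    };

    inline SavedData unpack_saved_data(const int64_t* p) {
      return SavedData{
          p[0],
          p[1],
          reinterpret_cast<const int64_t*>(p[2]),
          reinterpret_cast<const int32_t*>(p[3]),
          p[4]};
    }
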

 static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     at::TensorList all_inputs,
     c10::SymIntArrayRef output_shape_group_ref) {
+#ifdef USE_ROCM
   TORCH_CHECK_VALUE(
       all_inputs.size() > 4,
       "all_inputs size must be larger than 4, but got ",
       all_inputs.size());

   // all_input size = group_size * 2 (from grads, indices)
-  //    + 2 args_tensor2 + 2 saved_data + 1 first input
+  //    + 2 args_tensors + 2 saved_data + 1 first input
   const int64_t group_size = (all_inputs.size() - 5) / 2;

   const Tensor& fwd_input = all_inputs[2 * group_size + 4];
   const Tensor& saved_data_small = all_inputs[2 * group_size + 2];
   const Tensor& saved_data_large = all_inputs[2 * group_size + 3];
-
+
   const Tensor& first_indices = all_inputs[group_size];
   const int64_t output_dim = fwd_input.dim();
@@ -583,7 +692,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   const bool use_var_cols_large = saved_data_large_ptr[1];
   int64_t* warp_offsets_group_large = reinterpret_cast<int64_t*>(saved_data_large_ptr[2]);
   int32_t* num_cols_group_large = reinterpret_cast<int32_t*>(saved_data_large_ptr[3]);
-  int64_t total_num_warps_large = saved_data_large_ptr[4];
+  int64_t total_num_warps_large = saved_data_large_ptr[4];

   TORCH_CHECK_VALUE(
       (count_small + count_large) == group_size,
@@ -591,6 +700,52 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       group_size,
       " but got ",
       (count_small + count_large));
+#else
+  TORCH_CHECK_VALUE(
+      all_inputs.size() > 2,
+      "all_inputs size must be larger than 2, but got ",
+      all_inputs.size());
+
+  // all_input size = group_size * 2 (from grads, indices)
+  //    + 1 args_tensor + 1 saved_data + 1 first input
+  const int64_t group_size = (all_inputs.size() - 3) / 2;
+
+  const Tensor& fwd_input = all_inputs[2 * group_size + 2];
+  const int64_t output_dim = fwd_input.dim();
+  const Tensor& saved_data = all_inputs[2 * group_size + 1];
+  const Tensor& first_indices = all_inputs[group_size];
+
+  auto grad_output_group = std::vector<Tensor>(
+      all_inputs.cbegin(), all_inputs.cbegin() + group_size);
+  std::vector<int64_t> output_shape_group;
+  output_shape_group.reserve(output_shape_group_ref.size());
+  for (const auto& i : output_shape_group_ref) {
+    output_shape_group.push_back(i.as_int_unchecked());
+  }
+
+  auto indices_group = std::vector<Tensor>(
+      all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);
+
+  // Retrieve saved data
+  TORCH_CHECK(
+      saved_data.device() == at::kCPU, "Tensor saved_data must be on CPU.");
+  TORCH_CHECK(
+      saved_data.is_contiguous(), "Tensor saved_data must be contiguous.");
+  const int64_t* saved_data_ptr = saved_data.const_data_ptr<int64_t>();
+  // Check that the size is the same
+  TORCH_CHECK_VALUE(
+      saved_data_ptr[0] == group_size,
+      "The size of saved_data[0] must match group_size. Expect ",
+      group_size,
+      " but got ",
+      saved_data_ptr[0]);
+  const bool use_var_cols = saved_data_ptr[1];
+  const int64_t* warp_offsets_group =
+      reinterpret_cast<int64_t*>(saved_data_ptr[2]);
+  const int32_t* num_cols_group =
+      reinterpret_cast<int32_t*>(saved_data_ptr[3]);
+  int64_t total_num_warps = saved_data_ptr[4];
+#endif

   // We checked in forward that all output rows are the same for all member
   // in the group
@@ -613,6 +768,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
   }

+#ifdef USE_ROCM
   // Allocate Tensors for ptrs of grad output and input, and indices
   Tensor args_tensor_small = at::empty(
       {count_small * 3},
       at::TensorOptions().dtype(at::kLong).pinned_memory(true));
@@ -632,6 +788,20 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   int64_t* grad_output_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>();
   int64_t* grad_input_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>() + count_large;
   int64_t* indices_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>() + 2 * count_large;
+#else
+  // Allocate Tensor for ptrs of grad output and input, and indices
+  Tensor args_tensor = at::empty(
+      {group_size * 3},
+      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
+  // Ensure that args_tensor is contiguous
+  TORCH_CHECK(
+      args_tensor.is_contiguous(), "Tensor args_tensor must be contiguous.");
+  int64_t* grad_output_ptrs = args_tensor.mutable_data_ptr<int64_t>();
+  int64_t* grad_input_ptrs =
+      args_tensor.mutable_data_ptr<int64_t>() + group_size;
+  int64_t* indices_ptrs =
+      args_tensor.mutable_data_ptr<int64_t>() + 2 * group_size;
+#endif
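
The backward args tensor is a single int64 buffer of 3 * group_size entries, partitioned into three consecutive pointer tables. A sketch of the indexing convention (hypothetical helper, not in the source):

    // Sketch of the backward args_tensor layout: one int64 tensor of
    // 3 * group_size entries, split into three consecutive pointer tables.
    #include <cstdint>

    enum class ArgTable { kGradOutput = 0, kGradInput = 1, kIndices = 2 };

    inline int64_t* table_base(int64_t* args, ArgTable t, int64_t group_size) {
      // [0, gs): grad-output ptrs | [gs, 2gs): grad-input ptrs |
      // [2gs, 3gs): indices ptrs
      return args + static_cast<int64_t>(t) * group_size;
    }
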

   int64_t group_grad_input_numel = 0;
   std::vector<int64_t> grad_input_numels;
@@ -642,9 +812,11 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   std::vector<c10::MaybeOwned<Tensor>> grad_output_contigs;
   grad_output_contigs.reserve(group_size);

+#ifdef USE_ROCM
   const int cols_per_warp = get_group_index_select_cols_per_warp();
   int64_t idx_small = 0;
   int64_t idx_large = 0;
+#endif

   for (const auto i : c10::irange(group_size)) {
     const auto& grad = grad_output_group[i];
@@ -661,6 +833,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     grad_input_numels.push_back(grad_input_numel);
     group_grad_input_numel += grad_input_numel;

+#ifdef USE_ROCM
     const auto grad_reshaped = grad.reshape({grad.size(0), -1});
     const auto num_cols_ = grad_reshaped.size(1);
     if (num_cols_ < cols_per_warp) {
       grad_output_ptrs_small[idx_small] =
           reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
       idx_small++;
     } else {
       grad_output_ptrs_large[idx_large] =
           reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
       idx_large++;
     }
+#else
+    // Put all grad output/input pointers in an array
+    grad_output_ptrs[i] =
+        reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
+#endif
   }

   // Allocate a big tensor to avoid calling many small elementwise kernels
@@ -693,9 +871,11 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       " but got ",
       output_group.size());

+#ifdef USE_ROCM
   // Reset the counters of the small and large arguments
   idx_small = 0;
   idx_large = 0;
+#endif

   // Reshape grad inputs and obtain their pointers
   for (int i = 0; i < group_size; i++) {
@@ -712,6 +892,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         group_size,
         " must be contiguous.");

+#ifdef USE_ROCM
     const auto output_group_reshaped =
         grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
     const auto num_cols_ = output_group_reshaped.size(1);
     if (num_cols_ < cols_per_warp) {
       grad_input_ptrs_small[idx_small] =
           reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
       idx_small++;
     } else {
       grad_input_ptrs_large[idx_large] =
           reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
       idx_large++;
     }
+#else
+    grad_input_ptrs[i] =
+        reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
+#endif

     // 2) Add group_size gradients for inputs
     outputs.push_back(output_group[i]);
   }

+#ifdef USE_ROCM
   // Reset the counters of the small and large arguments
   idx_small = 0;
   idx_large = 0;
+#endif

   // Calculate indices_ptrs
   std::vector<c10::MaybeOwned<Tensor>> index_contigs;
   index_contigs.reserve(group_size);
@@ -736,7 +923,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   for (const auto i : c10::irange(group_size)) {
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
-
+
+#ifdef USE_ROCM
     const auto grad_output_reshaped =
         grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
     const auto num_cols_ = grad_output_reshaped.size(1);
     if (num_cols_ < cols_per_warp) {
       indices_ptrs_small[idx_small] =
           reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
       idx_small++;
     } else {
       indices_ptrs_large[idx_large] =
           reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
       idx_large++;
     }
+#else
+    indices_ptrs[i] =
+        reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+#endif
   }

+#ifdef USE_ROCM
   // Transfer grad output pointers to GPU
   args_tensor_small = args_tensor_small.to(first_indices.device(), /*non_blocking=*/true);
   args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true);
@@ -787,6 +980,25 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         use_var_cols_large,
         /*use_small_emb_dim=*/false);
   }
+#else
+  // Transfer grad output pointers to GPU
+  args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
+
+  group_index_select_or_add_cuda(
+      args_tensor.const_data_ptr<int64_t>(),
+      args_tensor.const_data_ptr<int64_t>() + group_size,
+      args_tensor.const_data_ptr<int64_t>() + 2 * group_size,
+      warp_offsets_group,
+      num_cols_group,
+      fwd_input.scalar_type(),
+      first_indices.scalar_type(),
+      fwd_input.device().index(),
+      num_input_rows,
+      total_num_warps,
+      group_size,
+      /*use_index_select=*/false,
+      use_var_cols);
+#endif

   return outputs;
 }
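
With use_index_select=false the same kernel runs in scatter-add mode: each gradient row is atomically accumulated into the row of grad_input selected by the forward indices. A CPU reference of those semantics (illustrative function, row-major buffers assumed):

    // CPU reference for the backward semantics exercised above: with
    // use_index_select=false the kernel computes
    // grad_input[indices[r], :] += grad_output[r, :],
    // which the GPU version implements with gpuAtomicAddNoReturn.
    #include <cstdint>
    #include <vector>

    void index_add_reference(
        const std::vector<float>& grad_output, // [num_output_rows x num_cols]
        const std::vector<int64_t>& indices,   // [num_output_rows]
        std::vector<float>& grad_input,        // [num_input_rows x num_cols], zeroed
        int64_t num_cols) {
      const int64_t rows = static_cast<int64_t>(indices.size());
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < num_cols; ++c) {
          grad_input[indices[r] * num_cols + c] +=
              grad_output[r * num_cols + c];
        }
      }
    }
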
From 63fa242c1259de1ff5bf4c4edefd485190d5d414 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Mon, 2 Feb 2026 15:50:37 +0000
Subject: [PATCH 25/25] sparse_ops_cpu.cpp: corrects TORCH_CHECK conditions

---
 fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index ee8af2499f..cd0ef6dc7b 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -3512,9 +3512,9 @@ torch::autograd::variable_list group_index_select_dim0(
   auto res = forward_op.call(
       all_indices_input_tensor, static_cast<int64_t>(group_size));
 #ifdef USE_ROCM
-  TORCH_CHECK(res.size() >= group_size + 4); // ROCm: +4 tensors (2 args, 2 saved_data)
+  TORCH_CHECK(res.size() >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)
 #else
-  TORCH_CHECK(res.size() >= group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data)
+  TORCH_CHECK(res.size() == group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data)
 #endif
   // only return the outputs (the first group_size elements)
   res.resize(group_size);
@@ -3627,9 +3627,9 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
           .typed<decltype(group_index_select_dim0_gpu_impl)>();
   auto result = forward_op.call(all_indices_input, group_size);
 #ifdef USE_ROCM
-  TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 4); // ROCm: +4 tensors
+  TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)
 #else
-  TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 2); // CUDA: +2 tensors
+  TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2); // CUDA: +2 tensors
 #endif
   ctx->saved_data["group_size"] = group_size;

@@ -3677,9 +3677,9 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
   const auto saved_tensors = ctx->get_saved_variables();
 #ifdef USE_ROCM
-  TORCH_CHECK(saved_tensors.size() >= group_size + 5); // ROCm: indices + 4 tensors + fwd_input
+  TORCH_CHECK(saved_tensors.size() >= group_size + 3); // ROCm: >= to handle both CPU (+3) and GPU (+5)
 #else
-  TORCH_CHECK(saved_tensors.size() >= group_size + 3); // CUDA: indices + 2 tensors + fwd_input
+  TORCH_CHECK(saved_tensors.size() == group_size + 3); // CUDA: indices + 2 tensors + fwd_input
 #endif
   std::vector<c10::SymInt> output_shape_group;
   int i = 0;
   while (true) {