-
Notifications
You must be signed in to change notification settings - Fork 9
Split Kernel Optimization for group_index_select_or_add_2d_kernel
#142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main_12162025_upstream
Are you sure you want to change the base?
Changes from all commits
85caa29
e4f1dba
ff1b9b6
439a51a
2a85d73
2f54140
e0edc40
81bf648
93b5a2e
17d8d4c
b6cec91
4576e59
71be9d2
e5dcf52
f24a6cd
39afe28
10b692d
b9ae864
c54af4f
65a3f84
ab14c59
598b2de
6a3fef1
554466f
adb728e
9310b72
63fa242
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,14 +37,26 @@ int get_group_index_select_unroll_factor() { | |
| return GROUP_INDEX_SELECT_UNROLL_FACTOR; | ||
| } | ||
|
|
||
| #ifdef USE_ROCM | ||
| template < | ||
| typename index_t, | ||
| typename scalar_t, | ||
| bool USE_INDEX_SELECT, | ||
| bool USE_VAR_COLS, | ||
| bool USE_SMALL_EMB_DIM, | ||
| int UNROLL_FACTOR, | ||
| int COLS_PER_WARP, | ||
| int LOG_COLS_PER_WARP> | ||
| #else | ||
| template < | ||
| typename index_t, | ||
| typename scalar_t, | ||
| bool USE_INDEX_SELECT, | ||
| bool USE_VAR_COLS, | ||
| int UNROLL_FACTOR, | ||
| int COLS_PER_WARP, | ||
| int LOG_COLS_PER_WARP> | ||
| #endif | ||
| __global__ | ||
| __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( | ||
| const int64_t* input_ptrs, | ||
|
|
@@ -84,39 +96,39 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( | |
| member_warp_id = warp_id - warp_offsets_group[member_id]; | ||
| } else { | ||
| // All columns are the same | ||
| member_id = warp_id / (warps_per_row * num_work_rows); | ||
| member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); | ||
| #ifdef USE_ROCM | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we really need to wrap it in |
||
| if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { | ||
| // Need to ensure that [member_id] and [member_warp_id] are calculated | ||
| // correctly for the small embedding dimension path below | ||
| if constexpr (USE_SMALL_EMB_DIM) { | ||
| // Small embedding: pack multiple rows per warp | ||
| const auto rows_per_warp = COLS_PER_WARP / num_cols; | ||
| const auto warps_per_member = | ||
| DIV_ROUND_UP(num_work_rows, rows_per_warp); | ||
| member_id = warp_id / warps_per_member; | ||
| member_warp_id = warp_id % warps_per_member; | ||
| } else { | ||
| #endif | ||
| // Large embedding: one or more warps per row | ||
| member_id = warp_id / (warps_per_row * num_work_rows); | ||
| member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); | ||
| #ifdef USE_ROCM | ||
| } | ||
| #endif // USE_ROCM | ||
| #endif | ||
| } | ||
|
|
||
| #ifdef USE_ROCM | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as in https://github.com/ROCm/FBGEMM/pull/142/changes#r2758719145 |
||
| if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { | ||
| // Optimized path for small embedding dimensions | ||
| if constexpr (USE_SMALL_EMB_DIM) { | ||
| // Small embedding dimension: pack multiple rows per warp | ||
| // Each warp processes 'rows_per_warp' rows | ||
| const auto rows_per_warp = COLS_PER_WARP / num_cols; | ||
| const int64_t start_row = member_warp_id * rows_per_warp; | ||
| int64_t start_row = member_warp_id * rows_per_warp; | ||
|
|
||
| // Since we are processing multiple rows within the warp, we need to | ||
| // map each lane to a specific row, in addition to the column | ||
| const auto local_row = (threadIdx.x * UNROLL_FACTOR) / | ||
| num_cols; // the row ID within the set of rows handled by this warp | ||
| const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; | ||
| const int64_t current_row = start_row + | ||
| local_row; // the actual row within the table processed by this lane | ||
|
|
||
| // local_row may be out of bounds for the last few lanes in the warp if | ||
| // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are | ||
| // within num_work_rows | ||
| int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp | ||
| int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; | ||
| int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane | ||
|
|
||
| // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0] | ||
| // and we also need to confirm that we are within num_work_rows | ||
|
Comment on lines
+122
to
+131
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did we remove |
||
| if (local_row < rows_per_warp && current_row < num_work_rows) { | ||
| scalar_t* input = | ||
| reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset; | ||
|
|
@@ -129,18 +141,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( | |
| for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { | ||
| // Compile time conditional | ||
| if constexpr (USE_INDEX_SELECT) { | ||
| output[current_row * num_cols + i] = | ||
| LDG(&input[idx * num_cols + i]); | ||
| output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]); | ||
| } else { | ||
| gpuAtomicAddNoReturn( | ||
| &output[idx * num_cols + i], input[current_row * num_cols + i]); | ||
| } | ||
| } | ||
| } | ||
| } else { | ||
| // Large embedding dimensions use >= 1 warp per row | ||
| // which is the default codepath for non-ROCm as well | ||
| #endif // USE_ROCM | ||
| #endif | ||
| // Large embedding dimension: one or more warps per row | ||
| const auto row = member_warp_id / warps_per_row; | ||
| const auto col_offset = | ||
| ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + | ||
|
|
@@ -164,7 +174,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( | |
| } | ||
| #ifdef USE_ROCM | ||
| } | ||
| #endif // USE_ROCM | ||
| #endif | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -181,7 +191,11 @@ DLL_PUBLIC void group_index_select_or_add_cuda( | |
| const int64_t total_num_warps, | ||
| const int group_size, | ||
| const bool use_index_select, | ||
| const bool use_var_cols) { | ||
| const bool use_var_cols | ||
| #ifdef USE_ROCM | ||
| ,const bool use_small_emb_dim | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We may force |
||
| #endif | ||
| ) { | ||
| if (group_size == 0) { | ||
| return; | ||
| } | ||
|
|
@@ -197,6 +211,32 @@ DLL_PUBLIC void group_index_select_or_add_cuda( | |
| max_grid_size); | ||
| dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); | ||
|
|
||
| #ifdef USE_ROCM | ||
| // Kernel launch macro with USE_SMALL_EMB_DIM template parameter | ||
| #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS, USE_SMALL_EMB_DIM) \ | ||
| FBGEMM_LAUNCH_KERNEL( \ | ||
| (group_index_select_or_add_2d_kernel< \ | ||
| index_t, \ | ||
| scalar_t, \ | ||
| USE_INDEX_SELECT, \ | ||
| USE_VAR_COLS, \ | ||
| USE_SMALL_EMB_DIM, \ | ||
| GROUP_INDEX_SELECT_UNROLL_FACTOR, \ | ||
| GROUP_INDEX_SELECT_COLS_PER_WARP, \ | ||
| GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ | ||
| grid_size, \ | ||
| block_size, \ | ||
| 0, \ | ||
| at::cuda::getCurrentCUDAStream(), \ | ||
| input_ptrs, \ | ||
| output_ptrs, \ | ||
| indices_ptrs, \ | ||
| warp_offsets_group, \ | ||
| num_cols_group, \ | ||
| num_work_rows, \ | ||
| group_size) | ||
| #else | ||
| // Kernel launch macro for CUDA (no USE_SMALL_EMB_DIM) | ||
| #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ | ||
| FBGEMM_LAUNCH_KERNEL( \ | ||
| (group_index_select_or_add_2d_kernel< \ | ||
|
|
@@ -218,11 +258,46 @@ DLL_PUBLIC void group_index_select_or_add_cuda( | |
| num_cols_group, \ | ||
| num_work_rows, \ | ||
| group_size) | ||
| #endif | ||
|
|
||
| AT_DISPATCH_INDEX_TYPES( | ||
| indices_scalar_type, "group_index_select_2d_wrapper_1", [&] { | ||
| FBGEMM_DISPATCH_FLOATING_TYPES( | ||
| input_scalar_type, "group_index_select_2d_wrapper_2", [&] { | ||
| #ifdef USE_ROCM | ||
| if (use_small_emb_dim) { | ||
| // Small embedding dimension: pack multiple rows per warp | ||
| if (use_index_select) { | ||
| if (use_var_cols) { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, true); | ||
| } else { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, true); | ||
| } | ||
| } else { | ||
| if (use_var_cols) { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, true); | ||
| } else { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, true); | ||
| } | ||
| } | ||
| } else { | ||
| // Large embedding dimension: one or more warps per row | ||
| if (use_index_select) { | ||
| if (use_var_cols) { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, false); | ||
| } else { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, false); | ||
| } | ||
| } else { | ||
| if (use_var_cols) { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, false); | ||
| } else { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, false); | ||
| } | ||
| } | ||
| } | ||
| #else | ||
| // CUDA: Standard path only (no small embedding optimization) | ||
| if (use_index_select) { | ||
| if (use_var_cols) { | ||
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true); | ||
|
|
@@ -236,6 +311,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda( | |
| INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false); | ||
| } | ||
| } | ||
| #endif | ||
| }); | ||
| }); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3508,9 +3508,14 @@ torch::autograd::variable_list group_index_select_dim0( | |
| at::Dispatcher::singleton() | ||
| .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") | ||
| .typed<decltype(group_index_select_dim0_autograd_impl)>(); | ||
|
|
||
| auto res = forward_op.call( | ||
| all_indices_input_tensor, static_cast<int64_t>(group_size)); | ||
| TORCH_CHECK(res.size() == group_size + 2); | ||
| all_indices_input_tensor, static_cast<int64_t>(group_size)); | ||
| #ifdef USE_ROCM | ||
| TORCH_CHECK(res.size() >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's have an exact comparison against expected size since |
||
| #else | ||
| TORCH_CHECK(res.size() == group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data) | ||
| #endif | ||
| // only return the outputs (the first group_size elements) | ||
| res.resize(group_size); | ||
| return res; | ||
|
|
@@ -3621,7 +3626,12 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( | |
| .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") | ||
| .typed<decltype(group_index_select_dim0_forward_impl_cpu)>(); | ||
| auto result = forward_op.call(all_indices_input, group_size); | ||
| TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2); | ||
| #ifdef USE_ROCM | ||
| TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as in https://github.com/ROCm/FBGEMM/pull/142/changes#r2758749386 |
||
| #else | ||
| TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2); // CUDA: +2 tensors | ||
| #endif | ||
| ctx->saved_data["group_size"] = group_size; | ||
|
|
||
| auto [input_group, indices_group] = | ||
| group_index_select_dim0_unpack(all_indices_input, group_size); | ||
|
|
@@ -3654,17 +3664,23 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( | |
| torch::autograd::variable_list GroupIndexSelectDim0Op::backward( | ||
| torch::autograd::AutogradContext* ctx, | ||
| torch::autograd::variable_list grad_output_group) { | ||
| TORCH_CHECK(grad_output_group.size() >= 2); | ||
| if (grad_output_group.size() == 2) { | ||
|
|
||
| auto group_size = ctx->saved_data["group_size"].toInt(); | ||
| TORCH_CHECK(static_cast<int64_t>(grad_output_group.size()) >= group_size); | ||
|
|
||
| if (group_size == 0) { | ||
| // empty outputs | ||
| return torch::autograd::variable_list(1); | ||
| } | ||
| // remove redundant grads | ||
| auto group_size = grad_output_group.size() - 2; | ||
| grad_output_group.resize(group_size); | ||
|
|
||
| const auto saved_tensors = ctx->get_saved_variables(); | ||
| TORCH_CHECK(saved_tensors.size() == group_size + 3); | ||
| #ifdef USE_ROCM | ||
| TORCH_CHECK(saved_tensors.size() >= group_size + 3); // ROCm: >= to handle both CPU (+3) and GPU (+5) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as in https://github.com/ROCm/FBGEMM/pull/142/changes#r2758749386 |
||
| #else | ||
| TORCH_CHECK(saved_tensors.size() == group_size + 3); // CUDA: indices + 2 tensors + fwd_input | ||
| #endif | ||
| std::vector<c10::SymInt> output_shape_group; | ||
| int i = 0; | ||
| while (true) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussed this offline. I believe that code duplication might be reduced by having unified kernel template for both Nvidia and ROCm paths, and using
std::visitwith platform specificstd::variantas in https://github.com/ROCm/FBGEMM/pull/139/changes#diff-6f509196a8893b5345f5e615251ce85ea5f575b81c1e9136fff764a899d92562R329-R407. This way we won't need to duplicate invoke macro and compilation time/binary size will be aligned with expectations for each platform. It will also make adding new template parameters easier.