
Implement pre-sorting, caching and contiguous warp processing in group_index_select #144

Open
avbokovoy wants to merge 1 commit into abokovoi/upstream from abokovoi/group-index-sort-and-cache-opt

Conversation

@avbokovoy

Follow-up of #139

The differences are:

  1. Reduced #ifdef USE_ROCM usage in favor of if constexpr (OPT_BOOL).
  2. Added a compile-time host-side codegen guard for the kernel (CUDA vs ROCm).
  3. Fixed an issue with the trailing row cache flush.

@avbokovoy avbokovoy self-assigned this Mar 3, 2026

@aryaman-gupta aryaman-gupta left a comment


The PR introduces crucial optimizations for the group_index_select_or_add_2d_kernel. The majority of the code is clean, and the separation of the ROCm and CUDA codepaths has been done well.

Most of these changes were already reviewed in #139. I have left a few comments that I think should be addressed before merging. Some of these are design choices, and the PR could proceed even if the code is not modified.

auto sorted_indices = at::empty_like(contiguous_indices);
auto reverse_indices = at::empty(
    contiguous_indices.sizes(),
    contiguous_indices.options().dtype(at::kLong));


This should be at::kInt, as we had previously discussed: reverse_indices tracks the positions of the elements, and the number of elements is limited to INT_MAX.

Same with original_positions below.

auto res = forward_op.call(
    all_indices_input_tensor, static_cast<int64_t>(group_size));
TORCH_CHECK(res.size() == group_size + 2);
TORCH_CHECK(res.size() == group_size + 4);


We discussed previously that we could keep the CPU pass unchanged, and therefore change this condition to TORCH_CHECK(res.size() >= group_size + 2);. Did you change your mind about that?
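The relaxed check can be sketched as a small predicate (names are illustrative, not from the PR): the CPU path keeps returning group_size + 2 tensors while the GPU path now returns group_size + 4, and a single `>=` comparison covers both layouts.

```cpp
#include <cstddef>

// Hypothetical helper mirroring the proposed
// TORCH_CHECK(res.size() >= group_size + 2): accepts both the
// unchanged CPU result layout (+2) and the new GPU layout (+4).
bool result_size_ok(std::size_t res_size, std::size_t group_size) {
  return res_size >= group_size + 2;
}
```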

// to match return format in CUDA implementation
// (group_size outputs, 1 args_tensor, 1 saved_data)
// (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor)
output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));


As in the previous comment, if we decide to keep the CPU path unchanged, these additional tensors should be removed.

// all input size = group_size * 2 (from grads, indices)
// + 1 args_tensor + 1 saved_data + 1 first output
const int64_t group_size = static_cast<int64_t>((all_inputs.size() - 3) / 2);
const int64_t group_size = static_cast<int64_t>((all_inputs.size() - 5) / 2);


Following on the above comments, group_size could then be saved between forward and backward passes as:

ctx->saved_data["group_size"] = group_size;
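A minimal mock of that idea, assuming nothing beyond standard C++ (MockAutogradCtx stands in for the real autograd context): the forward pass stashes group_size, and the backward pass reads it back instead of deriving it from all_inputs.size() with a layout-dependent formula like (n - 3) / 2 vs (n - 5) / 2.

```cpp
#include <cstdint>
#include <map>
#include <string>

// Illustrative stand-in for torch::autograd::AutogradContext's
// saved_data map; not the real API.
struct MockAutogradCtx {
  std::map<std::string, int64_t> saved_data;
};

// Forward pass: record group_size once.
void forward_save(MockAutogradCtx& ctx, int64_t group_size) {
  ctx.saved_data["group_size"] = group_size;
}

// Backward pass: recover it without inspecting the input layout.
int64_t backward_group_size(const MockAutogradCtx& ctx) {
  return ctx.saved_data.at("group_size");
}
```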

int64_t warp_offset = 0;
bool use_var_cols = false;

Tensor sorted_indices_storage =


Do we need these tensors in the forward pass when the sorting actually takes place in the backward pass?

Comment on lines +532 to +533
int64_t* warp_offsets_group = reinterpret_cast<int64_t*>(saved_data_ptr[4]);
int32_t* num_cols_group = reinterpret_cast<int32_t*>(saved_data_ptr[5]);


Should be const int64_t*

Comment on lines +155 to +171
case at::ScalarType::Byte:
  dispatch(uint8_t{});
  break;
case at::ScalarType::Char:
  dispatch(int8_t{});
  break;
case at::ScalarType::Short:
  dispatch(int16_t{});
  break;
case at::ScalarType::Int:
  dispatch(int32_t{});
  break;
case at::ScalarType::Long:
  dispatch(int64_t{});
  break;
default:
  TORCH_CHECK(


Indentation could be added as:

      case at::ScalarType::Byte: {
          dispatch(uint8_t{});
          break;
      }

false));
};

switch (scalar_type) {


Can a PyTorch macro like AT_DISPATCH_INTEGRAL_TYPES be used here?
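AT_DISPATCH_INTEGRAL_TYPES would indeed collapse the hand-written switch into one call site with a single generic lambda. A self-contained sketch of the pattern (MockScalarType and dispatch_integral are mocks, not the real ATen macro, which is unavailable here):

```cpp
#include <cstddef>
#include <cstdint>

// Mock of at::ScalarType restricted to the integral tags used above.
enum class MockScalarType { Byte, Char, Short, Int, Long };

// One generic dispatcher standing in for AT_DISPATCH_INTEGRAL_TYPES:
// the lambda is instantiated once per integral type, so the body
// (e.g. a kernel launch) is written once instead of once per case.
template <typename F>
auto dispatch_integral(MockScalarType t, F&& f) {
  switch (t) {
    case MockScalarType::Byte:  return f(uint8_t{});
    case MockScalarType::Char:  return f(int8_t{});
    case MockScalarType::Short: return f(int16_t{});
    case MockScalarType::Int:   return f(int32_t{});
    default:                    return f(int64_t{});
  }
}

// Example payload: report the index width the kernel would use.
inline std::size_t index_width(MockScalarType t) {
  return dispatch_integral(t, [](auto tag) { return sizeof(tag); });
}
```

With the real macro the call site would look like `AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "group_index_select", [&] { /* use scalar_t */ });`, with scalar_t bound inside the lambda.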

for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
  // Compile-time conditional
  if constexpr (USE_INDEX_SELECT) {
    if constexpr (USE_CACHE) {


USE_CACHE is always false for the forward pass. Why not simplify the code by removing the condition?
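The compile-time behavior at issue can be sketched with an illustrative template (flag names follow the PR, but the function and values are made up): instantiating with USE_CACHE = false discards the cache branch at compile time, which is exactly why the condition is redundant in the forward-pass variant.

```cpp
// Hypothetical per-element function mirroring the kernel's flag
// structure; return values are arbitrary markers for the paths.
template <bool USE_INDEX_SELECT, bool USE_CACHE>
int process_element(int v) {
  if constexpr (USE_INDEX_SELECT) {
    if constexpr (USE_CACHE) {
      return v + 100;  // cached path (backward pass only)
    }
    return v + 1;      // uncached path (what the forward pass runs)
  }
  return v - 1;        // index_add path
}
```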

