Implement pre-sorting, caching and contiguous warp processing in group_index_select #144
avbokovoy wants to merge 1 commit into abokovoi/upstream from
Conversation
…_index_select kernel
aryaman-gupta left a comment
The PR introduces crucial optimizations for the group_index_select_or_add_2d_kernel. Most of the code is clean, and the separation of the ROCm and CUDA code paths is well done.
Most of these changes were already reviewed in #139. I have left a few comments that I think should be looked at before merging. Some of these are design choices, so the PR could proceed to merge even if the code is not modified.
auto sorted_indices = at::empty_like(contiguous_indices);
auto reverse_indices = at::empty(
    contiguous_indices.sizes(),
    contiguous_indices.options().dtype(at::kLong));
This should be at::kInt, as we had previously discussed: reverse_indices tracks the positions of the elements, and the number of elements is bounded by INT_MAX.
The same applies to original_positions below.
auto res = forward_op.call(
    all_indices_input_tensor, static_cast<int64_t>(group_size));
TORCH_CHECK(res.size() == group_size + 2);
TORCH_CHECK(res.size() == group_size + 4);
We discussed previously that we could keep the CPU pass unchanged, and therefore change this condition to TORCH_CHECK(res.size() >= group_size + 2);. Did you change your mind about that?
// to match return format in CUDA implementation
// (group_size outputs, 1 args_tensor, 1 saved_data)
// (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor)
output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
As in the previous comment, if we decide to keep the CPU path unchanged, these additional tensors should be removed.
// all input size = group_size * 2 (from grads, indices)
// + 1 args_tensor + 1 saved_data + 1 first output
const int64_t group_size = static_cast<int64_t>((all_inputs.size() - 3) / 2);
const int64_t group_size = static_cast<int64_t>((all_inputs.size() - 5) / 2);
Following on the above comments, group_size could then be saved between forward and backward passes as:
int64_t warp_offset = 0;
bool use_var_cols = false;

Tensor sorted_indices_storage =
Do we need these tensors in the forward pass when the sorting actually takes place in the backward pass?
int64_t* warp_offsets_group = reinterpret_cast<int64_t*>(saved_data_ptr[4]);
int32_t* num_cols_group = reinterpret_cast<int32_t*>(saved_data_ptr[5]);
case at::ScalarType::Byte:
  dispatch(uint8_t{});
  break;
case at::ScalarType::Char:
  dispatch(int8_t{});
  break;
case at::ScalarType::Short:
  dispatch(int16_t{});
  break;
case at::ScalarType::Int:
  dispatch(int32_t{});
  break;
case at::ScalarType::Long:
  dispatch(int64_t{});
  break;
default:
  TORCH_CHECK(
Indentation could be added as:
case at::ScalarType::Byte: {
dispatch(uint8_t{});
break;
}
    false));
};

switch (scalar_type) {
Can a PyTorch macro like AT_DISPATCH_INTEGRAL_TYPES be used here?
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
  // Compile time conditional
  if constexpr (USE_INDEX_SELECT) {
    if constexpr (USE_CACHE) {
USE_CACHE is always false for the forward pass. Why not simplify the code by removing the condition?
Follow-up of #139
The differences are:
Dropped #ifdef USE_ROCM usage in favor of if constexpr (OPT_BOOL).