Optimizations for index_select_scalar_cumsum_kernel #137
amd-wsung102 wants to merge 16 commits into will/upstream
avbokovoy left a comment
Minor tweaks are needed, but overall LGTM
    auto grid_size = cuda_calc_xblock_count(
    int grid_size = 0;
    #ifdef USE_ROCM
    constexpr int VEC = 4;
We can pass that to the kernel as a template parameter with a default value, which makes it easier to tweak if needed. Also, the variable name VEC is not self-explanatory.
I renamed VEC to ENTRIES_PER_THREAD so it is more intuitive. It was originally passed to the kernel as a new template parameter, but Li said to avoid changing the template and kernel API. Should I change it back to a template parameter?
ENTRIES_PER_THREAD (previously VEC) is now passed to the kernel as a template parameter.
@liligwu Could you elaborate on why we should avoid changing the template and kernel API?
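For reference, the template-parameter approach discussed in this thread can be sketched as follows. This is a minimal host-side C++ illustration, not the kernel from this PR: `chunked_inclusive_cumsum` and `threads_needed` are hypothetical names, and the outer loop only models how each thread would handle `ENTRIES_PER_THREAD` consecutive entries. Because the parameter has a default value, call sites that do not name it compile unchanged.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-side stand-in for the kernel (not the FBGEMM signature).
// ENTRIES_PER_THREAD defaults to 4, so callers that omit it are unaffected.
template <typename scalar_t, int ENTRIES_PER_THREAD = 4>
std::vector<scalar_t> chunked_inclusive_cumsum(
    const std::vector<scalar_t>& input) {
  static_assert(ENTRIES_PER_THREAD > 0, "ENTRIES_PER_THREAD must be positive");
  constexpr std::size_t kChunk = ENTRIES_PER_THREAD;

  std::vector<scalar_t> output(input.size());
  scalar_t running = 0;
  // Each outer iteration models one thread's chunk of kChunk entries.
  for (std::size_t base = 0; base < input.size(); base += kChunk) {
    for (std::size_t i = 0; i < kChunk && base + i < input.size(); ++i) {
      running += input[base + i];
      output[base + i] = running;
    }
  }
  return output;
}

// Hypothetical helper: how many threads cover num_entries when each thread
// handles ENTRIES_PER_THREAD entries (ceiling division).
template <int ENTRIES_PER_THREAD = 4>
constexpr int threads_needed(int num_entries) {
  return (num_entries + ENTRIES_PER_THREAD - 1) / ENTRIES_PER_THREAD;
}
```

With this shape, `chunked_inclusive_cumsum<int>(v)` uses the default of 4, while `chunked_inclusive_cumsum<int, 2>(v)` overrides it for tuning. The result is the same either way; only the work split changes.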
fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
Outdated
fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
Outdated
aryaman-gupta left a comment
I took a glance at the code and left a couple of comments. Looks good to me otherwise.
    }

    // Faster path for single block
    if (!multi_block) {
As in the other file, consider passing multi_block as a compile-time parameter, or splitting the function in two and dispatching to the appropriate one at runtime.
Got it, I will try it and measure the results.
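A rough sketch of the suggestion above, written as host-side C++ rather than the actual kernel: `cumsum_impl` and `cumsum_dispatch` are hypothetical names, and the "multi-block" branch only imitates a per-block scan followed by a carry from preceding blocks. The flag becomes a template parameter, so each instantiation contains a single path, and the choice is made once at runtime in the dispatcher.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the kernel body. MULTI_BLOCK is a compile-time
// constant, so the branch not taken can be eliminated by the compiler.
template <bool MULTI_BLOCK>
std::vector<int> cumsum_impl(const std::vector<int>& input,
                             std::size_t block_size) {
  std::vector<int> output(input.size());
  if (!MULTI_BLOCK) {
    // Faster path for single block: one running sum, no cross-block carry.
    int running = 0;
    for (std::size_t i = 0; i < input.size(); ++i) {
      running += input[i];
      output[i] = running;
    }
  } else {
    // Multi-block path: scan each block locally, then add the total of all
    // preceding blocks. Assumes block_size > 0.
    int block_offset = 0;
    for (std::size_t start = 0; start < input.size(); start += block_size) {
      const std::size_t end = std::min(start + block_size, input.size());
      int local = 0;
      for (std::size_t i = start; i < end; ++i) {
        local += input[i];
        output[i] = block_offset + local;
      }
      block_offset += local;
    }
  }
  return output;
}

// Runtime dispatch: decide once on the host, then call the specialized version.
std::vector<int> cumsum_dispatch(const std::vector<int>& input,
                                 std::size_t block_size) {
  const bool multi_block = block_size > 0 && input.size() > block_size;
  return multi_block ? cumsum_impl<true>(input, block_size)
                     : cumsum_impl<false>(input, block_size);
}
```

Both instantiations must produce identical results, so comparing them on the same input is a cheap sanity check. Whether this helps the real kernel depends on how much the runtime branch costs there, which would need to be measured.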
Optimization Changes
#ifdef USE_ROCM
Test Result
Reduced the duration of index_select_scalar_cumsum_kernel by 1.11 us, yielding a 1.3x speedup.
Test Plan
Unit tests passed.
Submission Checklist