Implement asynchronous LDS loads for MI350 #138
avbokovoy wants to merge 3 commits into abokovoi/upstream
Conversation
aryaman-gupta
left a comment
I have taken a look at the PR and identified a couple of areas that looked tricky to me. It may make sense to double-check these to confirm that the logic is correct.
```cpp
    __builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr));
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
asm volatile("s_mov_b32 m0, %0\n"
             "global_load_lds_dwordx4 %1, off\n" ::"s"(smem),
```
This enforces that the entire warp loads a contiguous chunk of memory from global to LDS. What happens when the row is not large enough, i.e., kWarpSize > NumUint4LoadsPerRow? As I understand it, this would assign different row_load_idx values to different lanes in the wavefront.
You might want to confirm that this case is correctly handled.
It will load 16 bytes per lane (16 x 64 for the whole wave) from the corresponding vector register (the address differs from lane to lane) into the LDS pointer with corresponding strides. Global memory doesn't have to be contiguous. The sanity of the loads is checked outside of this function and is handled with pred_guard. Tail or OOB loads are redirected to the zero_tile global memory chunk, which contains zeros; that case is then handled properly by the kernel logic.
```cpp
  asm volatile("cp.async.wait_all;\n" ::);
#elif defined(USE_ROCM) && \
    (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__)
```
Does this mean that there is no wait instruction defined for ROCm version >= 7.2? Just wanted to confirm that this is intentional.
It should be handled by the intrinsic, which assumes that this function has a side effect.
```cpp
cp_async_zfill_cg<sizeof(uint4)>(
    &buffers[warp_idx][i][input_row_idx]
            [row_load_idx + uint4_loads_per_row * packed_bag_load_idx],
    &row_v[inner_i][row_load_idx],
    final_valid);
```
It seems to me that in PackedMode, the smem_ptr passed by the different lanes to cp_async_zfill_cg is strided. However, cp_async_zfill_cg uses lane 0's smem_ptr and performs a contiguous memory read into that location. This seems suspicious to me, so I wanted to point it out. I suppose you have verified that the logic is correct, @avbokovoy?
I guess this comment applies here as well:
#138 (comment)
This PR implements direct HBM->LDS stores in the TBE inference kernel. There are two major changes:
Due to pre-7.2 ROCm limitations, we are forced to use inline assembly to get 16B loads to work, so manual synchronization was added. For ROCm >= 7.2, we use proper intrinsics to handle memory synchronization.
This change brings a ~10% performance boost on average for the weighted and unweighted cases. We may try to push it further by also doing async loads for the indices' weights.
cc: @amirakb89 you might be interested