
Implement asynchronous LDS loads for MI350#138

Open
avbokovoy wants to merge 3 commits into abokovoi/upstream from abokovoi/async-lds-inference-opt

Conversation

@avbokovoy

This PR implements direct HBM->LDS stores in the TBE inference kernel. There are two major changes:

  1. Row data isn't loaded in place; instead, we store pointers to global memory and write the actual data into LDS according to the predicate. If the predicate is false, we fall back to a small 16B chunk of static device memory, pre-allocated once and filled with zeros.
  2. HBM->LDS 16B loads are implemented for ROCm >= 7.0 and MI350. We could extend support to MI30* through 4B loads, but that brings no performance benefit because it would introduce address-transposition overhead and 4x more load operations. A reference implementation is available here: pytorch@fe52557.

Because pre-7.2 ROCm lacks the necessary features, we are forced to use inline assembly to get 16B loads working, so manual synchronization was added. For ROCm >= 7.2, we use proper intrinsics to handle memory synchronization.
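The zero-fill fallback described above can be modeled on the host. This is a minimal sketch, not the device code from the PR: `cp_async_zfill_model`, `Uint4`, and `zero_tile` are illustrative names, and `memcpy` stands in for the actual 16B HBM->LDS hardware copy.

```cpp
#include <cstring>
#include <cstdint>

// 16-byte vector, mirroring the uint4 granularity of the LDS loads.
struct alignas(16) Uint4 { uint32_t x, y, z, w; };

// Stands in for the small static device chunk that is pre-allocated
// once and filled with zeros.
static const Uint4 zero_tile = {0, 0, 0, 0};

// Host-side model of the predicated copy: when pred is false, the
// source is redirected to the zero-filled chunk instead of the load
// being skipped, so the destination always holds defined data.
inline void cp_async_zfill_model(Uint4* smem_ptr, const Uint4* global_ptr, bool pred) {
  const Uint4* src = pred ? global_ptr : &zero_tile;
  std::memcpy(smem_ptr, src, sizeof(Uint4));
}
```

The point of redirecting rather than skipping is that the subsequent kernel logic can consume every LDS slot unconditionally, without re-checking the predicate.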

This change brings a ~10% performance boost on average for both weighted and unweighted cases. We may try to push it further by also doing async loads for the indices' weights.

cc: @amirakb89 you might be interested

@avbokovoy avbokovoy self-assigned this Dec 19, 2025
@avbokovoy avbokovoy added the enhancement New feature or request label Dec 19, 2025

@aryaman-gupta aryaman-gupta left a comment


I have taken a look at the PR and identified a couple of areas that look tricky to me. It may make sense to double-check these to confirm that the logic is correct.

__builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr));
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
asm volatile("s_mov_b32 m0, %0\n"
"global_load_lds_dwordx4 %1, off\n" ::"s"(smem),


This enforces that the entire warp loads a contiguous chunk of memory from global to LDS. What happens when the row is not large enough, i.e., kWarpSize > NumUint4LoadsPerRow? As I understand it, this would assign a different row_load_idx to different lanes in the wavefront:

uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow;

You might want to confirm that this case is correctly handled
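The lane-to-index split under discussion can be sketched as follows. This is a host-side model using names from the quoted snippet; `kWarpSize = 64` matches AMD wavefronts, and `NumUint4LoadsPerRow = 4` (e.g. a 64B row covered by 16B loads) is an assumed illustrative value.

```cpp
#include <cstdint>

constexpr uint32_t kWarpSize = 64;
constexpr uint32_t NumUint4LoadsPerRow = 4;  // assumed: 64B row / 16B per load

struct LaneTarget {
  uint32_t row_idx;       // which row this lane services
  uint32_t row_load_idx;  // which 16B chunk within that row
};

// With kWarpSize > NumUint4LoadsPerRow, consecutive lanes wrap onto
// different rows, so one wavefront services several rows per round.
inline LaneTarget lane_target(uint32_t lane, uint32_t load_round) {
  const uint32_t load_idx = load_round * kWarpSize + lane;
  return {load_idx / NumUint4LoadsPerRow, load_idx % NumUint4LoadsPerRow};
}
```

Under this model, lanes 0-3 target row 0, lanes 4-7 target row 1, and so on, which is exactly the per-lane divergence the comment asks about.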


@avbokovoy avbokovoy Jan 12, 2026


It will load 16 bytes per lane (16 x 64 for the whole wave) from the address held in the corresponding vector register (the address differs from lane to lane) into the LDS pointer with the corresponding strides. Global memory doesn't have to be contiguous. The sanity of the loads is checked outside of this function and is handled with pred_guard. Tail or OOB loads are redirected to the zero_tile global memory chunk, which contains zeros; the kernel logic then handles that case properly.
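The wave-wide behaviour described in this reply can be modeled on the host as a loop over lanes. This is a sketch only: `wave_load_model`, `Vec16B`, and the per-lane pointer array are illustrative, and the sequential loop stands in for the 64 lanes executing in lockstep.

```cpp
#include <cstring>
#include <cstdint>
#include <array>

constexpr int kWaveSize = 64;
struct alignas(16) Vec16B { uint8_t b[16]; };

// Models the pre-zeroed static device chunk used as the fallback.
static const Vec16B zero_tile{};

// Each lane holds its own global pointer (so sources need not be
// contiguous); lanes whose predicate is false are redirected to the
// zero tile, so the wave always writes 64 x 16B of defined data.
inline void wave_load_model(Vec16B* lds,
                            const std::array<const Vec16B*, kWaveSize>& src,
                            const std::array<bool, kWaveSize>& pred) {
  for (int lane = 0; lane < kWaveSize; ++lane) {
    const Vec16B* p = pred[lane] ? src[lane] : &zero_tile;
    std::memcpy(&lds[lane], p, sizeof(Vec16B));
  }
}
```

Note that an invalid lane's source pointer is never dereferenced, which is why OOB or tail addresses are harmless in this scheme.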


asm volatile("cp.async.wait_all;\n" ::);
#elif defined(USE_ROCM) && \
(ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__)


Does this mean that there is no wait instruction for ROCm >= 7.2? Just wanted to confirm that this is intentional.


@avbokovoy avbokovoy Jan 12, 2026


It should be handled by the intrinsic, which the compiler assumes has a side effect.

Comment on lines +203 to +206
cp_async_zfill_cg<sizeof(uint4)>(
&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx],
&row_v[inner_i][row_load_idx],
final_valid);


It seems to me that in PackedMode, the smem_ptr passed by the different lanes to cp_async_zfill_cg is strided. However, cp_async_zfill_cg uses lane 0's smem_ptr and performs a contiguous copy into that location. This seems suspicious to me, so I wanted to point it out. I suppose you have verified that the logic is correct, @avbokovoy?
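The striding the comment refers to is visible in the LDS offset expression from the quoted hunk. A minimal sketch of that index arithmetic, with an assumed `uint4_loads_per_row = 4` (the names come from the snippet; the value is illustrative):

```cpp
#include <cstdint>

constexpr uint32_t uint4_loads_per_row = 4;  // assumed: 64B row / 16B per load

// Destination slot within the LDS buffer, as in the quoted expression
// row_load_idx + uint4_loads_per_row * packed_bag_load_idx.
// In PackedMode, packed_bag_load_idx differs across lanes, so the
// per-lane destinations are strided rather than contiguous.
inline uint32_t lds_slot(uint32_t row_load_idx, uint32_t packed_bag_load_idx) {
  return row_load_idx + uint4_loads_per_row * packed_bag_load_idx;
}
```

For example, two lanes with the same row_load_idx but packed_bag_load_idx 0 and 1 land `uint4_loads_per_row` slots apart, which is the non-contiguous pattern the reviewer flags.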



I guess this comment applies here as well:
#138 (comment)
