-
Notifications
You must be signed in to change notification settings - Fork 9
Implement asynchronous LDS loads for MI350 #138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: abokovoi/upstream
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -170,6 +170,10 @@ __device__ __forceinline__ void cp_async_wait() { | |||
| #if __CUDA_ARCH__ >= 800 | ||||
|
|
||||
| asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); | ||||
| #elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ | ||||
| (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) | ||||
|
|
||||
| __builtin_amdgcn_s_waitcnt(0); | ||||
| #endif | ||||
| } | ||||
|
|
||||
|
|
@@ -179,13 +183,27 @@ __device__ __forceinline__ void cp_async_wait<0>() { | |||
| #if __CUDA_ARCH__ >= 800 | ||||
|
|
||||
| asm volatile("cp.async.wait_all;\n" ::); | ||||
| #elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ | ||||
| (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) | ||||
|
|
||||
| __builtin_amdgcn_s_waitcnt(0); | ||||
| #endif | ||||
| } | ||||
|
|
||||
| template<typename T> | ||||
| __device__ __forceinline__ uint32_t hip_cvta_to_shared_address(const T* ptr) { | ||||
| // First get the address as a size_t to handle all pointer sizes | ||||
| size_t addr = reinterpret_cast<size_t>(ptr); | ||||
|
|
||||
| // Extract the lower 32 bits which represent the shared memory offset | ||||
| // This is safe because shared memory addresses are always within 32-bit range | ||||
| return static_cast<uint32_t>(addr & 0xFFFFFFFF); | ||||
| } | ||||
|
|
||||
| /// Partial specialization | ||||
| template <int SizeInBytes> | ||||
| __device__ __forceinline__ void | ||||
| cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { | ||||
| cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_guard) { | ||||
| #if __CUDA_ARCH__ >= 800 | ||||
| static_assert( | ||||
| SizeInBytes == 16, | ||||
|
|
@@ -199,6 +217,36 @@ cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { | |||
| "n"(SizeInBytes), | ||||
| "r"(src_in_bytes)); | ||||
|
|
||||
| // if ROCm version >= 7.2 and MI350 | ||||
| #elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR > 7 || \ | ||||
| (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR >= 2)) && defined(__gfx950__) | ||||
| static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; | ||||
| static_assert( | ||||
| SizeInBytes == 16, | ||||
| "cp_async_zfill_cg() function is implemented for 16B inputs only"); | ||||
| // Due to LLVM bug, we can't use SizeInBytes directly | ||||
| // in __builtin_amdgcn_global_load_lds intrinsic until | ||||
| // ROCm 7.11: | ||||
| // https://github.com/llvm/llvm-project/pull/175767 | ||||
| // | ||||
| // Make sure you modify this #if branch if SizeInBytes | ||||
| // support range is extended | ||||
| const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; | ||||
| __builtin_amdgcn_global_load_lds(const_cast<void*>(src_ptr), smem_ptr, 16, 0, 0); | ||||
| // if MI350 | ||||
| #elif defined(USE_ROCM) && defined(__gfx950__) | ||||
| static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; | ||||
| static_assert( | ||||
| SizeInBytes == 16, | ||||
| "cp_async_zfill_cg() function is implemented for 16B inputs only"); | ||||
|
|
||||
| uint32_t smem = | ||||
| __builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr)); | ||||
| const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; | ||||
| asm volatile("s_mov_b32 m0, %0\n" | ||||
| "global_load_lds_dwordx4 %1, off\n" ::"s"(smem), | ||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

This enforces that the entire warp will load a contiguous chunk of memory from global to LDS. What happens when the row is not large enough? See `FBGEMM/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu`, line 153 at commit 856f1af.
You might want to confirm that this case is correctly handled.
Author

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

It will load 16 bytes (16 x 64 for the whole wave) from the corresponding vector register (the address differs from lane to lane) into the LDS pointer with corresponding strides. Global memory doesn't have to be contiguous. The sanity of the loads is checked outside of this function and is handled with `pred_guard`.
||||
| "v"(static_cast<const uint32_t *>(src_ptr)) | ||||
| :); | ||||
| #else | ||||
| static_assert(SizeInBytes == 16, ""); | ||||
| using AccessType = uint4; | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to me that in `PackedMode`, the `smem_ptr` passed to the different lanes in the `cp_async_zfill_cg` function is strided. However, the `cp_async_zfill_cg` function uses lane 0's `smem_ptr` and performs a contiguous memory read into that location. This seems suspicious to me, so I wanted to point it out. I suppose you have verified that the logic is correct, @avbokovoy?

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

I guess this comment applies here as well:
#138 (comment)