From 856f1af37fb2205083045ca5036f0ec7896b0e74 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 19 Dec 2025 10:34:09 +0000 Subject: [PATCH 1/3] Implement asynchronous LDS loads for MI350 --- ...rd_quantized_split_nbit_kernel_template.cu | 42 ++++++++++-------- .../embedding_forward_template_helpers.cuh | 43 ++++++++++++++++++- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 25aca5336b..d8f5c3b386 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -165,10 +165,10 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no #pragma unroll for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { - uint4 row_data_v[kRowUnroll]; const uint4* row_v[kRowUnroll]; int32_t idx_v[kRowUnroll]; int32_t cache_idx_v[kRowUnroll]; + bool row_valid_v[kRowUnroll]; #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { uint32_t i = outer_i + inner_i; @@ -176,50 +176,54 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + row_valid_v[inner_i] = valid; } #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - valid = valid && (idx_v[inner_i] != -1); + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && row_valid_v[inner_i]); + bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1); if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); - } else - if (valid) { + } else if (final_valid) { row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); } else { row_v[inner_i] = reinterpret_cast(&weights[0]); } + row_valid_v[inner_i] = final_valid; } #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { uint32_t i = outer_i + inner_i; - row_data_v[inner_i] = row_v[inner_i][row_load_idx]; - } - uint4 zeros = {0, 0, 0, 0}; - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); - uint4 data = valid ? row_data_v[inner_i] : zeros; + bool final_valid = row_valid_v[inner_i]; if constexpr (PackedMode) { // Store row data with uint4_loads_per_row offset - buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + cp_async_zfill_cg( + &buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], + &row_v[inner_i][row_load_idx], + final_valid); } else { - buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + cp_async_zfill_cg( + &buffers[warp_idx][i][input_row_idx][row_load_idx], + &row_v[inner_i][row_load_idx], + final_valid); } - {% if weighted %} + } + {% if weighted %} + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1); if (row_load_idx == 0) { // Use only one thread to load the index weight to prevent a race // condition when writing to the shared memory buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + final_valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; } - {% endif %} } + {% endif %} } {%- endif %} diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index f05bed94d6..de45161b56 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -170,6 +170,10 @@ __device__ __forceinline__ void cp_async_wait() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#elif defined(USE_ROCM) && \ + (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) + + __builtin_amdgcn_s_waitcnt(0); #endif } @@ -179,13 +183,27 @@ __device__ __forceinline__ void cp_async_wait<0>() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_all;\n" ::); +#elif defined(USE_ROCM) && \ + (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) + + __builtin_amdgcn_s_waitcnt(0); #endif } +template +__device__ __forceinline__ uint32_t hip_cvta_to_shared_address(const T* ptr) { + // First get the address as a size_t to handle all pointer sizes + size_t addr = reinterpret_cast(ptr); + + // Extract the lower 32 bits which represent the shared memory offset + // This is safe because shared memory addresses are always within 32-bit range + return static_cast(addr & 0xFFFFFFFF); +} + /// Partial specialization template __device__ __forceinline__ void -cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { +cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_guard) { #if __CUDA_ARCH__ >= 800 static_assert( SizeInBytes == 16, @@ -199,6 +217,29 @@ cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { "n"(SizeInBytes), "r"(src_in_bytes)); +#elif defined(USE_ROCM) + static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; + static_assert( + SizeInBytes == 16, + "cp_async_zfill_cg() function is implemented for 16B inputs only"); + +// if ROCm version >= 7.2 and MI350 +#if (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) && defined(__gfx950__) + + const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; + __builtin_amdgcn_global_load_lds(src_ptr, smem_ptr, SizeInBytes, 0, 0); +// if ROCm version in [7.0, 7.2) and MI350 +#elif ROCM_VERSION_MAJOR >= 7 && defined(__gfx950__) + + uint32_t smem = + __builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr)); + const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; + asm volatile("s_mov_b32 m0, %0\n" + "global_load_lds_dwordx4 %1, off\n" ::"s"(smem), + "v"(static_cast(src_ptr)) + :); +#endif // (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) || + // (ROCM_VERSION_MAJOR > 7) && defined(__gfx950__) #else static_assert(SizeInBytes == 16, ""); using AccessType = uint4; From dc3b15ba99eb4195db982dda20f63fc4fb6cc2f8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 14 Jan 2026 10:15:02 +0000 Subject: [PATCH 2/3] Hardcode size value in __builtin_amdgcn_global_load_lds intrinsic --- .../embedding_forward_template_helpers.cuh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index de45161b56..57da91cc7f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -225,9 +225,15 @@ cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_g // if ROCm version >= 7.2 and MI350 #if (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) && defined(__gfx950__) - + // Due to LLVM bug, we can't use SizeInBytes directly + // in __builtin_amdgcn_global_load_lds intrinsic until + // ROCm 7.11: + // https://github.com/llvm/llvm-project/pull/175767 + // + // Make sure you modify this #if branch if SizeInBytes + // support range is extended const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; - __builtin_amdgcn_global_load_lds(src_ptr, smem_ptr, SizeInBytes, 0, 0); + __builtin_amdgcn_global_load_lds(src_ptr, smem_ptr, 16, 0, 0); // if ROCm version in [7.0, 7.2) and MI350 #elif ROCM_VERSION_MAJOR >= 7 && defined(__gfx950__) @@ -238,8 +244,8 @@ cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_g "global_load_lds_dwordx4 %1, off\n" ::"s"(smem), "v"(static_cast(src_ptr)) :); -#endif // (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) || - // (ROCM_VERSION_MAJOR > 7) && defined(__gfx950__) +#endif // (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) && defined(__gfx950__) + #else static_assert(SizeInBytes == 16, ""); using AccessType = uint4; From 2c739ab4ecfe0c313cbaeb9dcc7ede42dba04073 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 19 Feb 2026 11:25:07 +0000 Subject: [PATCH 3/3] Fix ROCm version and arch guards --- .../embedding_forward_template_helpers.cuh | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index 57da91cc7f..9efdac4d01 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -170,8 +170,8 @@ __device__ __forceinline__ void cp_async_wait() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); -#elif defined(USE_ROCM) && \ - (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) __builtin_amdgcn_s_waitcnt(0); #endif @@ -183,8 +183,8 @@ __device__ __forceinline__ void cp_async_wait<0>() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_all;\n" ::); -#elif defined(USE_ROCM) && \ - (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) __builtin_amdgcn_s_waitcnt(0); #endif @@ -217,14 +217,13 @@ cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_g "n"(SizeInBytes), "r"(src_in_bytes)); -#elif defined(USE_ROCM) +// if ROCm version >= 7.2 and MI350 +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR > 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR >= 2)) && defined(__gfx950__) static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; static_assert( SizeInBytes == 16, "cp_async_zfill_cg() function is implemented for 16B inputs only"); - -// if ROCm version >= 7.2 and MI350 -#if (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) && defined(__gfx950__) // Due to LLVM bug, we can't use SizeInBytes directly // in __builtin_amdgcn_global_load_lds intrinsic until // ROCm 7.11: @@ -233,9 +232,13 @@ cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_g // Make sure you modify this #if branch if SizeInBytes // support range is extended const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; - __builtin_amdgcn_global_load_lds(src_ptr, smem_ptr, 16, 0, 0); -// if ROCm version in [7.0, 7.2) and MI350 -#elif ROCM_VERSION_MAJOR >= 7 && defined(__gfx950__) + __builtin_amdgcn_global_load_lds(const_cast(src_ptr), smem_ptr, 16, 0, 0); +// if MI350 +#elif defined(USE_ROCM) && defined(__gfx950__) + static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; + static_assert( + SizeInBytes == 16, + "cp_async_zfill_cg() function is implemented for 16B inputs only"); uint32_t smem = __builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr)); @@ -244,8 +247,6 @@ cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_g "global_load_lds_dwordx4 %1, off\n" ::"s"(smem), "v"(static_cast(src_ptr)) :); -#endif // (ROCM_VERSION_MAJOR >= 7 && ROCM_VERSION_MINOR >= 2) && defined(__gfx950__) - #else static_assert(SizeInBytes == 16, ""); using AccessType = uint4;