diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 25aca5336b..d8f5c3b386 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -165,10 +165,10 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no #pragma unroll for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { - uint4 row_data_v[kRowUnroll]; const uint4* row_v[kRowUnroll]; int32_t idx_v[kRowUnroll]; int32_t cache_idx_v[kRowUnroll]; + bool row_valid_v[kRowUnroll]; #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { uint32_t i = outer_i + inner_i; @@ -176,50 +176,54 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + row_valid_v[inner_i] = valid; } #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - valid = valid && (idx_v[inner_i] != -1); + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && row_valid_v[inner_i]); + bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1); if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); - } else - if (valid) { + } else if (final_valid) { row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); } else { row_v[inner_i] = reinterpret_cast(&weights[0]); } + row_valid_v[inner_i] = final_valid; } #pragma unroll for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { uint32_t i = outer_i + inner_i; - row_data_v[inner_i] = row_v[inner_i][row_load_idx]; - } - uint4 zeros = {0, 0, 0, 0}; - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); - uint4 data = valid ? row_data_v[inner_i] : zeros; + bool final_valid = row_valid_v[inner_i]; if constexpr (PackedMode) { // Store row data with uint4_loads_per_row offset - buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + cp_async_zfill_cg( + &buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], + &row_v[inner_i][row_load_idx], + final_valid); } else { - buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + cp_async_zfill_cg( + &buffers[warp_idx][i][input_row_idx][row_load_idx], + &row_v[inner_i][row_load_idx], + final_valid); } - {% if weighted %} + } + {% if weighted %} + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1); if (row_load_idx == 0) { // Use only one thread to load the index weight to prevent a race // condition when writing to the shared memory buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + final_valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; } - {% endif %} } + {% endif %} } {%- endif %} diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index f05bed94d6..9efdac4d01 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -170,6 +170,10 @@ __device__ __forceinline__ void cp_async_wait() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) + + __builtin_amdgcn_s_waitcnt(0); #endif } @@ -179,13 +183,27 @@ __device__ __forceinline__ void cp_async_wait<0>() { #if __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_all;\n" ::); +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__) + + __builtin_amdgcn_s_waitcnt(0); #endif } +template +__device__ __forceinline__ uint32_t hip_cvta_to_shared_address(const T* ptr) { + // First get the address as a size_t to handle all pointer sizes + size_t addr = reinterpret_cast(ptr); + + // Extract the lower 32 bits which represent the shared memory offset + // This is safe because shared memory addresses are always within 32-bit range + return static_cast(addr & 0xFFFFFFFF); +} + /// Partial specialization template __device__ __forceinline__ void -cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { +cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_guard) { #if __CUDA_ARCH__ >= 800 static_assert( SizeInBytes == 16, @@ -199,6 +217,36 @@ cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) { "n"(SizeInBytes), "r"(src_in_bytes)); +// if ROCm version >= 7.2 and MI350 +#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR > 7 || \ + (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR >= 2)) && defined(__gfx950__) + static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; + static_assert( + SizeInBytes == 16, + "cp_async_zfill_cg() function is implemented for 16B inputs only"); + // Due to LLVM bug, we can't use SizeInBytes directly + // in __builtin_amdgcn_global_load_lds intrinsic until + // ROCm 7.11: + // https://github.com/llvm/llvm-project/pull/175767 + // + // Make sure you modify this #if branch if SizeInBytes + // support range is extended + const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; + __builtin_amdgcn_global_load_lds(const_cast(src_ptr), smem_ptr, 16, 0, 0); +// if MI350 +#elif defined(USE_ROCM) && defined(__gfx950__) + static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0}; + static_assert( + SizeInBytes == 16, + "cp_async_zfill_cg() function is implemented for 16B inputs only"); + + uint32_t smem = + __builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr)); + const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile; + asm volatile("s_mov_b32 m0, %0\n" + "global_load_lds_dwordx4 %1, off\n" ::"s"(smem), + "v"(static_cast(src_ptr)) + :); #else static_assert(SizeInBytes == 16, ""); using AccessType = uint4;