Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,61 +165,65 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no

#pragma unroll
for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) {
uint4 row_data_v[kRowUnroll];
const uint4* row_v[kRowUnroll];
int32_t idx_v[kRowUnroll];
int32_t cache_idx_v[kRowUnroll];
bool row_valid_v[kRowUnroll];
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1;
row_valid_v[inner_i] = valid;
}


#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
valid = valid && (idx_v[inner_i] != -1);
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && row_valid_v[inner_i]);
bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1);
if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&lxu_cache_weights[static_cast<int64_t>(cache_idx_v[inner_i])][0]);
} else
if (valid) {
} else if (final_valid) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[static_cast<int64_t>(idx_v[inner_i]) * D_bytes]);
} else {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[0]);
}
row_valid_v[inner_i] = final_valid;
}
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
row_data_v[inner_i] = row_v[inner_i][row_load_idx];
}
uint4 zeros = {0, 0, 0, 0};
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1);
uint4 data = valid ? row_data_v[inner_i] : zeros;
bool final_valid = row_valid_v[inner_i];
if constexpr (PackedMode) {
// Store row data with uint4_loads_per_row offset
buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data;
cp_async_zfill_cg<sizeof(uint4)>(
&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx],
&row_v[inner_i][row_load_idx],
final_valid);
Comment on lines +203 to +206

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that in PackedMode, the smem_ptr passed to the different lanes in the cp_async_zfill_cg function is strided. However, the cp_async_zfill_cg function uses lane 0's smem_ptr and performs a contiguous memory read into that location. This seems suspicious to me, so I wanted to point it out. I suppose you have verified that the logic is correct, @avbokovoy?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this comment applies here as well:
#138 (comment)

} else {
buffers[warp_idx][i][input_row_idx][row_load_idx] = data;
cp_async_zfill_cg<sizeof(uint4)>(
&buffers[warp_idx][i][input_row_idx][row_load_idx],
&row_v[inner_i][row_load_idx],
final_valid);
}
{% if weighted %}
}
{% if weighted %}
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1);
if (row_load_idx == 0) {
// Use only one thread to load the index weight to prevent a race
// condition when writing to the shared memory
buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] =
valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
final_valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
}
{% endif %}
}
{% endif %}
}
{%- endif %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ __device__ __forceinline__ void cp_async_wait() {
#if __CUDA_ARCH__ >= 800

asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__)

__builtin_amdgcn_s_waitcnt(0);
#endif
}

Expand All @@ -179,13 +183,27 @@ __device__ __forceinline__ void cp_async_wait<0>() {
#if __CUDA_ARCH__ >= 800

asm volatile("cp.async.wait_all;\n" ::);
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__)

__builtin_amdgcn_s_waitcnt(0);
#endif
}

template<typename T>
__device__ __forceinline__ uint32_t hip_cvta_to_shared_address(const T* ptr) {
  // Convert a generic device pointer into a 32-bit LDS (shared memory)
  // offset by discarding the upper half of the pointer value.
  const size_t full_addr = reinterpret_cast<size_t>(ptr);

  // NOTE(review): this assumes shared-memory addresses always fit in the
  // low 32 bits of the pointer; the truncating mask keeps only that part.
  return static_cast<uint32_t>(full_addr & static_cast<size_t>(0xFFFFFFFFu));
}

/// Partial specialization
template <int SizeInBytes>
__device__ __forceinline__ void
cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) {
cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_guard) {
#if __CUDA_ARCH__ >= 800
static_assert(
SizeInBytes == 16,
Expand All @@ -199,6 +217,36 @@ cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) {
"n"(SizeInBytes),
"r"(src_in_bytes));

// if ROCm version >= 7.2 and MI350
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR > 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR >= 2)) && defined(__gfx950__)
static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0};
static_assert(
SizeInBytes == 16,
"cp_async_zfill_cg() function is implemented for 16B inputs only");
// Due to LLVM bug, we can't use SizeInBytes directly
// in __builtin_amdgcn_global_load_lds intrinsic until
// ROCm 7.11:
// https://github.com/llvm/llvm-project/pull/175767
//
// Make sure you modify this #if branch if SizeInBytes
// support range is extended
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
__builtin_amdgcn_global_load_lds(const_cast<void*>(src_ptr), smem_ptr, 16, 0, 0);
// if MI350
#elif defined(USE_ROCM) && defined(__gfx950__)
static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0};
static_assert(
SizeInBytes == 16,
"cp_async_zfill_cg() function is implemented for 16B inputs only");

uint32_t smem =
__builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr));
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
asm volatile("s_mov_b32 m0, %0\n"
"global_load_lds_dwordx4 %1, off\n" ::"s"(smem),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enforces that the entire warp will load a contiguous chunk of memory from global to LDS. What happens when the row is not large enough, i.e., kWarpSize > NumUint4LoadsPerRow? As I understand it, this would assign different row_load_idx to different lanes in the wavefront

uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow;

You might want to confirm that this case is correctly handled

Copy link
Author

@avbokovoy avbokovoy Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will load 16 bytes per lane (16 x 64 for the whole wave) from the corresponding vector register (the address differs from lane to lane) into the LDS pointer with the corresponding strides. Global memory doesn't have to be contiguous. The validity of the loads is checked outside of this function and is handled with pred_guard. Trailing or out-of-bounds loads are redirected to the zero_tile global memory chunk, which contains zeroes. This is then handled properly by the kernel logic.

"v"(static_cast<const uint32_t *>(src_ptr))
:);
#else
static_assert(SizeInBytes == 16, "");
using AccessType = uint4;
Expand Down