Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
525 changes: 328 additions & 197 deletions fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,98 @@

using namespace fbgemm_gpu;

// --- Warp-broadcast helper macros -----------------------------------------
// Each X(i, j) macro below declares one variable named <name>_j_<i> holding
// the value fetched from lane (j + i) via SHFL_SYNC (warp shuffle).  The
// i == 0 variable (<name>_j_0) is always declared explicitly at the call
// site, so these helpers only cover i >= 1.
// NOTE(review): `L` and `B` are one-letter macro names defined at file scope
// of a generated .cu — collision-prone with other headers; consider a
// prefixed spelling (e.g. TBE_L / TBE_B). Renaming requires touching the
// REPEAT_##block_size call sites as well.
#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i);
#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i);
#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i);
#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i);
#define IDX_WEIGHT(i, j) at::acc_type<cache_t, true> idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i);

// REPEAT_<block_size>(X, j) expands X(i, j) for i = 1 .. block_size - 1;
// there is deliberately no X(0, j) (see note above).
#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j);
#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j);
#define REPEAT_2(X, j) X(1, j);
#define REPEAT_1(X, j) // No additional variables needed for block size 1

// REPEAT_I_S_<block_size>(X, j, m, n): 4-argument variant for helper macros
// that take two extra fixed arguments (m, n) in addition to (i, j).
#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n);
#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n);
#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n);
#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1

// --- Gradient load/accumulate helper macros --------------------------------
// Each GRAD_VEC* variant declares grad_out_vec_<i>, a Vec4TAcc<grad_t>
// loaded from the grad_output address of the i-th row of the current block
// (i from 1 to block_size-1). Which variant is instantiated is selected by
// the Jinja kernel configuration at the PROCESS_BLOCK call site:
// if nobag and is_index_select: grad_output is 1-D, addressed as
//   grad_offset + row * grad_stride + d
#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]);
// elif nobag: grad_output is 2-D, addressed as [row][d]
#define GRAD_VEC_N(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[l_j_##i][d]);
// elif vbe: flattened output, addressed as [0][per-row grad_offset + d]
#define GRAD_VEC_V(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]);
// else (bagged, non-VBE): row b with the table's column offset D_start
#define GRAD_VEC(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d);

// Accumulate grad_out_vec_<i> into grad_sum[vec] (i from 1 to block_size-1):
// weighted tables — multiply-accumulate with the row's index weight
#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i);
// unweighted tables — plain add
#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i);

// Core macro: process `block_size` sorted segment entries (rows) per
// iteration of the outer j-loop, accumulating their output gradients into
// grad_sum. Called in descending order (8, 4, 2, 1) so the largest block
// size handles the bulk and smaller ones drain the remainder; `j` persists
// across the calls.
// Parameters:
// - block_size: rows handled per iteration (8/4/2/1); must match a
//   REPEAT_<block_size> / REPEAT_I_S_<block_size> expander above
// - unroll_count: trip-count bound of the inner per-vector (d) loop
// - j: loop induction variable, declared by the caller and advanced here
// The Jinja {%- if %} lines select, at template-render time, which per-row
// variables are broadcast (nobag / vbe / bagged; optionally weighted); the
// rendered result is a single C++ statement-block. The `\` continuations
// survive rendering because `{%- ... %}` lines are consumed by Jinja.
// NOTE(review): the {%- set d = ... %} below defines a Jinja variable that
// this macro never interpolates (the C++ `d` is computed inline) — it looks
// vestigial; confirm before removing.
#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \
for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \
{%- if nobag %}
int32_t l_j_0 = SHFL_SYNC(l, j); \
REPEAT_##block_size(L, j) \
{%- elif vbe %}
/* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \
const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \
/* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \
REPEAT_##block_size(GRAD_OFFSET, j) \
{%- else %}
int32_t b_j_0 = SHFL_SYNC(b, j); \
REPEAT_##block_size(B, j) \
int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \
REPEAT_##block_size(D_START, j) \
{%- endif %}
{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \
REPEAT_##block_size(IDX_WEIGHT, j) \
{%- endif %}
{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
\
for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \
const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \
/* Generate block_size Vec4TAcc objects and accumulate them */ \
Vec4TAcc<grad_t> grad_out_vec_0( \
{%- if nobag and is_index_select %}
&grad_output[grad_offset + l_j_0 * grad_stride + d] \
{%- elif nobag %}
&grad_output[l_j_0][d] \
{%- elif vbe %}
&grad_output[0][grad_offset_j_0 + d] \
{%- else %}
&grad_output[b_j_0][0] + D_start_j_0 + d \
{%- endif %}
); \
{%- if nobag and is_index_select %}
REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \
{%- elif nobag %}
REPEAT_##block_size(GRAD_VEC_N, d) \
{%- elif vbe %}
REPEAT_##block_size(GRAD_VEC_V, d) \
{%- else %}
REPEAT_##block_size(GRAD_VEC, d) \
{%- endif %}
\
{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \
REPEAT_##block_size(FMA_GRAD, vec) \
{%- else %}
grad_sum[vec].add_(grad_out_vec_0); \
REPEAT_##block_size(ADD_GRAD, vec) \
{%- endif %}
} \
}

{%- if gen_once %}
{#- /*
The kernels in this section will be generated only once for all TBE configs
Expand Down Expand Up @@ -141,45 +233,21 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
? sorted_indice_weights[segment_start + sl_j]
: 0.0;
{%- endif %}
for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
{%- if nobag %}
int32_t l_j = SHFL_SYNC(l, j);
{%- elif vbe %}
const auto grad_offset_j = SHFL_SYNC(grad_offset, j);
{%- else %}
int32_t b_j = SHFL_SYNC(b, j);
int32_t D_start_j = SHFL_SYNC(D_start, j);
{%- endif %}

{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j = SHFL_SYNC(idx_weight, j);
{%- endif %}
int32_t j = 0;

{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
const int32_t d = {{ d }};
Vec4TAcc<grad_t> grad_out_vec(
{%- if nobag and is_index_select %}
// grad_output is 1d
&grad_output[grad_offset + l_j * grad_stride + d]
{%- elif nobag %}
&grad_output[l_j][d]
{%- elif vbe %}
&grad_output[0][grad_offset_j + d]
{%- else %}
&grad_output[b_j][0] + D_start_j + d
{%- endif %} // if nobag
);

{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec, idx_weight_j);
{%- else %}
grad_sum[vec].add_(grad_out_vec);
{%- endif %}
}
}
// Process blocks of different sizes with loop unrolling
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
}
{%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include "fbgemm_gpu/utils/assert_macros.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

Expand All @@ -47,6 +51,87 @@ using namespace fbgemm_gpu;
-}}
}()

// Macro: 4-way unrolled loop over pooled indices computing per-index
// gradient weights (dot products of embedding rows with grad_out), with
// cache usage controlled by `use_cache` (0 = read weights directly,
// 1 = prefer rows resident in lxu_cache_weights).
// - `use_cache` must be a literal 0/1: it is used both in runtime ternaries
//   (the cache_idx shuffles) and in `if constexpr`, so the dead path is
//   discarded at compile time.
// - Processes indices in groups of 4 (lanes j..j+3); the loop guard
//   `l_start + j + 3 < L` leaves any remainder (L % 4 tail) to be handled
//   by a scalar loop at the call site — presumably the `for (; j < ...)`
//   loop that follows; confirm at each expansion site.
// - Relies on the caller's scope for: j, L, l_start, offset_idx, cache_idx,
//   weights, lxu_cache_weights, grad_out, grad_indice_weights,
//   indices_start, D, kWarpSize, kVecWidth, kCacheLocationMissing.
// - `unroll_count` bounds the inner per-vector (d) loop trip count.
#define PROCESS_WEIGHTS_LOOP(use_cache, unroll_count) \
    for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { \
        /* Get offset indices (common logic) */ \
        const auto offset_idx_j0 = shfl_sync(offset_idx, j); \
        const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); \
        const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); \
        const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); \
        \
        /* Get cache indices only if use_cache is 1 (using compile-time condition) */ \
        const auto cache_idx_j0 = (use_cache) ? shfl_sync(cache_idx, j) : 0; \
        const auto cache_idx_j1 = (use_cache) ? shfl_sync(cache_idx, j+1) : 0; \
        const auto cache_idx_j2 = (use_cache) ? shfl_sync(cache_idx, j+2) : 0; \
        const auto cache_idx_j3 = (use_cache) ? shfl_sync(cache_idx, j+3) : 0; \
        \
        /* Gradient weight variables (common) */ \
        at::acc_type<cache_t, true> grad_indice_weight0 = 0.0; \
        at::acc_type<cache_t, true> grad_indice_weight1 = 0.0; \
        at::acc_type<cache_t, true> grad_indice_weight2 = 0.0; \
        at::acc_type<cache_t, true> grad_indice_weight3 = 0.0; \
        \
        /* Weight row accessors (common) */ \
        const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D); \
        const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D); \
        const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D); \
        const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D); \
        \
        /* Loop over vectors to compute gradients */ \
        for (int32_t vec = 0; vec < unroll_count && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { \
            const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; \
            \
            Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3; \
            \
            /* Load weights: choose logic based on use_cache (compile-time condition) */ \
            if constexpr (use_cache) { \
                /* Cache-aware loading (second code snippet logic) */ \
                weight0 = (cache_idx_j0 != kCacheLocationMissing) ? \
                    Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) : \
                    weight_row0.load(d); \
                weight1 = (cache_idx_j1 != kCacheLocationMissing) ? \
                    Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) : \
                    weight_row1.load(d); \
                weight2 = (cache_idx_j2 != kCacheLocationMissing) ? \
                    Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) : \
                    weight_row2.load(d); \
                weight3 = (cache_idx_j3 != kCacheLocationMissing) ? \
                    Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) : \
                    weight_row3.load(d); \
            } else { \
                /* Direct weight loading (first code snippet logic) */ \
                weight0 = weight_row0.load(d); \
                weight1 = weight_row1.load(d); \
                weight2 = weight_row2.load(d); \
                weight3 = weight_row3.load(d); \
            } \
            \
            /* Gradient calculation (common) */ \
            grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + \
                weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; \
            grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + \
                weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; \
            grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + \
                weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; \
            grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + \
                weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; \
        } \
        \
        /* Warp reduction and result assignment (common) */ \
        grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0); \
        grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1); \
        grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2); \
        grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3); \
        \
        if (threadIdx.x == 0) { \
            grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; \
            grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; \
            grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; \
            grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; \
        } \
    }

{%- for vbe in ([True, False]) %}
{%- set vdesc = "_vbe" if vbe else "" %}

Expand Down Expand Up @@ -98,8 +183,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
) {
constexpr int32_t kVecWidth = 4;
[[maybe_unused]] int error_code = 0;
[[maybe_unused]] int64_t error_value = 0;
int error_code = 0;
int64_t error_value = 0;

int32_t T = D_offsets.size(0) - 1;
auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -210,7 +295,20 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
)
{%- endif %}

for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int32_t j = 0;
{%- if not ssd and not dense and not use_vec_blocking and not vbe %}
// Currently for split_embedding_codegen_grad_indice_weights_kernel only
if (placement != PlacementType::MANAGED_CACHING) {
// no cache logic
#pragma unroll kFixedMaxVecsPerThread
PROCESS_WEIGHTS_LOOP(0, kFixedMaxVecsPerThread)
} else {
// with cache logic
#pragma unroll kFixedMaxVecsPerThread
PROCESS_WEIGHTS_LOOP(1, kFixedMaxVecsPerThread)
}
{%- endif %}
for (; j < kWarpSize && l_start + j < L; ++j) {
const auto offset_idx_j = shfl_sync(offset_idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
Expand Down Expand Up @@ -359,6 +457,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

CUDA_DEVICE_GUARD(dev_weights);
#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
}
else {
// Ensure we're running on a supported CDNA architecture (including MI350)
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
codegen/embedding_common_code_generator.py for more details
*/ #}

{{ instantiate_templates(use_subwarp_shuffle=False) }}
{{ instantiate_templates(use_subwarp_shuffle=True) }}

////////////////////////////////////////////////////////////////////////////////
#endif
Expand Down
Loading