Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3942d12
Add gfx950 build support + fp16 fix + index type fix
avbokovoy Jul 29, 2025
7a20f4c
Change int64_t to index_t as template parameters in load_raw_per_warp
avbokovoy Jul 29, 2025
1236b3a
Implement llvm fp16 buffer load for gfx950
avbokovoy Jul 29, 2025
3bc1ba8
Fix c-style half to float cast
avbokovoy Aug 11, 2025
6dcd104
Patch 256 half stores
avbokovoy Aug 11, 2025
dc3d3e0
cta_per_row workgroup optim
shbiswas834 Aug 8, 2025
c5a6b25
Added mi350 guards
shbiswas834 Aug 11, 2025
1057a22
Fix index overflow in row load
shbiswas834 Aug 12, 2025
e981269
cta_per_row workgroup reduce by 4 optim
shbiswas834 Aug 12, 2025
3ef8e56
Fix mixed_D frontend to backend connection
avbokovoy Aug 13, 2025
caf2e9e
changed max_segment_length_per_cta to 4096
kudomcho Aug 15, 2025
4b841bd
added rocm guards and removed comment
shbiswas834 Aug 18, 2025
63287f6
clean debug statements in Hip.cmake
liligwu Aug 20, 2025
b6d76f9
Merge pull request #121
shbiswas834 Aug 28, 2025
76d0914
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
a56c299
fix the bug in dimension 160 in ROCm optimization
liligwu Sep 18, 2025
54f296d
Cleanup optimized warp_per_row kernel
avbokovoy Aug 19, 2025
b3b9868
Add 320 embedding dim support for optimized warp_per_row kernel
avbokovoy Aug 20, 2025
8ae4724
changed the max length per warp and cta per row WG size
Sep 8, 2025
a5f48da
added DPP and changed max length per warp to 16k
kudomcho Sep 9, 2025
a4cceb7
guard max segment warp based on emb dim
kudomcho Sep 10, 2025
8b4f25c
added guarding opt of max segment for the case batch size list=1
kudomcho Sep 10, 2025
bf4769b
added condition to apply DPP warp reduce sum when emb matches backwar…
kudomcho Sep 16, 2025
2c44b8d
reverted pt2 autograd version back by 1 commit
kudomcho Sep 16, 2025
711e565
enabled DPP if emb dim in range of backward opt
Sep 16, 2025
af7392a
opt for grad_indice_weights kernel
Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
Expand Down Expand Up @@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
Expand Down
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
Expand Down Expand Up @@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
Expand All @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
Expand Down
141 changes: 132 additions & 9 deletions fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %}

{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}
////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -22,7 +29,9 @@
#include "fbgemm_gpu/utils/tensor_utils.h"
#include "fbgemm_gpu/utils/assert_macros.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

Expand Down Expand Up @@ -67,7 +76,8 @@ template <
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread
int32_t kFixedMaxVecsPerThread,
bool embDimMatch
>
__global__ __launch_bounds__(kForwardMaxThreads) void
{{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_{{ vbdesc }}kernel(
Expand Down Expand Up @@ -210,7 +220,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
)
{%- endif %}

for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int32_t j = 0;
{%- if not ssd and not dense and not use_vec_blocking and not vbe %}
// Currently for split_embedding_codegen_grad_indice_weights_kernel only
for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
const auto offset_idx_j0 = shfl_sync(offset_idx, j);
const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);

const auto cache_idx_j0 = shfl_sync(cache_idx, j);
const auto cache_idx_j1 = shfl_sync(cache_idx, j+1);
const auto cache_idx_j2 = shfl_sync(cache_idx, j+2);
const auto cache_idx_j3 = shfl_sync(cache_idx, j+3);

at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;

[[maybe_unused]] const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
[[maybe_unused]] const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
[[maybe_unused]] const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
[[maybe_unused]] const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;

Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
if (placement == PlacementType::MANAGED_CACHING) {
weight0 = (cache_idx_j0 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) :
weight_row0.load(d);

weight1 = (cache_idx_j1 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) :
weight_row1.load(d);

weight2 = (cache_idx_j2 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) :
weight_row2.load(d);

weight3 = (cache_idx_j3 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) :
weight_row3.load(d);
} else {
weight0 = weight_row0.load(d);
weight1 = weight_row1.load(d);
weight2 = weight_row2.load(d);
weight3 = weight_row3.load(d);
}

grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
}

grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

if (threadIdx.x == 0) {
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
}
}
{%- endif %}
for (; j < kWarpSize && l_start + j < L; ++j) {
const auto offset_idx_j = shfl_sync(offset_idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
Expand Down Expand Up @@ -261,7 +346,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
}
grad_indice_weight =
warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight);
warpReduceAllSum<at::acc_type<cache_t, true>, kWarpSize, embDimMatch>(grad_indice_weight);
if (threadIdx.x == 0) {
{%- if use_vec_blocking %}
if (vec_start == 0) {
Expand Down Expand Up @@ -359,7 +444,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

CUDA_DEVICE_GUARD(dev_weights);

#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
}
else {
// Ensure we're running on a supported CDNA architecture (including MI350)
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
// offsets = [B x T + 1]
Expand Down Expand Up @@ -407,13 +501,42 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
"{}_embedding_codegen_grad_indice_weights{}_{}kernel".format(
mdesc, vdesc, vbdesc)
%}
FBGEMM_LAUNCH_KERNEL(
({{ kernel_name }}<
auto kernel_name_ = {{ kernel_name }}<
emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread>),
kFixedMaxVecsPerThread,
/*embDimMatch=*/ false>;
#ifdef USE_ROCM
{%- if is_optimized_hip_kernel_supported_mode %}
const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half
|| dev_weights.scalar_type() == at::ScalarType::Float;

if (!mixed_D && supported_weights_type && rocm::is_supported_cdna())
{
{%- for kDimSize in [64, 128, 160, 192, 256, 320] %}
{%- for kWeightDecayMode in [0, 1, 2] %}
if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
{
kernel_name_ =
{{ kernel_name }}
<
emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread,
/*embDimMatch=*/ true
>;
}
{%- endfor %}
{%- endfor %}
}
{%- endif %}
#endif
FBGEMM_LAUNCH_KERNEL(
kernel_name_,
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@

{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
Expand Down Expand Up @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row

{%- endif %}

{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
Expand Down Expand Up @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
{%- endif %}
) {
{%- if not nobag %}
int32_t T = D_offsets.size(0) - 1;
{%- else %}
int32_t T = weights_offsets.size(0);
{%- endif %}


auto p_output_grad = grad_output.data();
auto p_emb_table = dev_weights.data();
auto p_hash_size_cumsum = hash_size_cumsum.data();
Expand All @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
auto batch = grad_output.size(0);
auto num_rows = dev_weights.size(0) / T / max_D;
{%- if weighted %}
constexpr bool is_weighted = true;
{%- else %}
Expand All @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
// weight_decay(_mode) is supplied as args.split_function_args_no_defaults
opt_karg.weight_decay_mode = weight_decay_mode_v;
opt_karg.weight_decay = weight_decay;
auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
assert(d >= 1 && d <= INT32_MAX);
uint8_t shift;
for(shift = 0; shift < 32; shift++)
if((1U << shift) >= d)
break;

uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
assert(magic <= 0xffffffffUL);

rocm::magic_div_u32_t result;
result.magic = magic;
result.shift = shift;
return result;
}(batch);

rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_kernel_arg_t,
Expand All @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
p_sorted_linear_indices_run,
p_sorted_linear_indices_cumulative_run_lengths,
p_sorted_linear_indices_num_runs,
{%- if not nobag %}
info_B_num_bits,
info_B_mask,
{%- endif %}
p_sorted_infos,
batch_mdiv,
max_segment_length_per_warp,
emb_dim,
batch,
num_rows,
T,
opt_karg
{%- if weighted %}
Expand Down Expand Up @@ -784,7 +766,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,
Expand Down
Loading