Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
85caa29
adds optimized path for small dimension sizes to group_index_select_o…
aryaman-gupta Dec 12, 2025
e4f1dba
group_index_select_or_add_2d_kernel: splits into two kernels dependin…
aryaman-gupta Dec 16, 2025
ff1b9b6
sparse_group_index.cu: edits some comments
aryaman-gupta Dec 16, 2025
439a51a
adds USE_ROCM guards to subwarp optimizations for group_index_select_…
aryaman-gupta Dec 16, 2025
2a85d73
sparse_group_index: handle UNROLL_FACTOR for small dimensions in grou…
aryaman-gupta Dec 18, 2025
2f54140
sparse_group_index: handle fixed-column-size case correctly in optimi…
aryaman-gupta Dec 18, 2025
e0edc40
group_index_select_or_add_2d_kernel: when num_cols < UNROLL_FACTOR, d…
aryaman-gupta Dec 18, 2025
81bf648
sparse_group_index: corrects macro invoking group_index_select_or_add…
aryaman-gupta Jan 7, 2026
93b5a2e
Merge branch 'aryaman/group-index-subwarp' into aryaman/group-index-o…
aryaman-gupta Jan 12, 2026
17d8d4c
fixes merge commit
aryaman-gupta Jan 12, 2026
b6cec91
sparse_group_index.cu: copies updates into small kernel
aryaman-gupta Jan 12, 2026
4576e59
sparse_ops_gpu.cpp: corrects handling of forward outputs and adds deb…
aryaman-gupta Jan 14, 2026
71be9d2
sparse_ops_cpu: adjusts to handle potentially higher amount of saved …
aryaman-gupta Jan 14, 2026
e5dcf52
Revert "sparse_ops_gpu.cpp: corrects handling of forward outputs and …
aryaman-gupta Jan 14, 2026
f24a6cd
sparse_ops_gpu.cpp: use different for small and large embedding func…
aryaman-gupta Jan 14, 2026
39afe28
sparse_ops_gpu.cpp: corrects computation of for multi-D inputs in ba…
aryaman-gupta Jan 15, 2026
10b692d
Revert "sparse_ops_gpu.cpp: corrects computation of for multi-D inpu…
aryaman-gupta Jan 15, 2026
b9ae864
sparse_ops_gpu: corrects group_size arguments for split kernels
aryaman-gupta Jan 15, 2026
c54af4f
sparse_ops_gpu.cpp: corrects computation of for multi-D inputs in ba…
aryaman-gupta Jan 15, 2026
65a3f84
sparse_ops_gpu: use const for temporary variables
aryaman-gupta Jan 15, 2026
ab14c59
sparse_group_index: prunes code in small and large group index select…
aryaman-gupta Jan 15, 2026
598b2de
Revert D87922263 (#5207)
q10 Dec 10, 2025
6a3fef1
combines split group_index_select_or_add_2d_kernel kernel with compil…
aryaman-gupta Jan 30, 2026
554466f
Merge remote-tracking branch 'origin/main_12162025_upstream' into ary…
aryaman-gupta Feb 2, 2026
adb728e
fixes merge issues
aryaman-gupta Feb 2, 2026
9310b72
guard changes with USE_ROCM
aryaman-gupta Feb 2, 2026
63fa242
sparse_ops_cpu.cpp: corrects TORCH_CHECK commands
aryaman-gupta Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,14 +982,37 @@ def group_index_select_dim0_gpu_impl_abstract(
# divide by 2 since sizeof(int64_t) / sizeof(int32_t) = 2
args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2))

ret.append(
# sizeof(int64_t) = 8, torch.uint8 = at::kByte
input_group[0].new_empty(
args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
# Runtime check for ROCm vs CUDA
is_rocm = torch.version.hip is not None

if is_rocm:
# ROCm: Allocate dual args_tensors and saved_data tensors
ret.append(
# sizeof(int64_t) = 8, torch.uint8 = at::kByte
input_group[0].new_empty(
args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
)
)

ret.append(
# sizeof(int64_t) = 8, torch.uint8 = at::kByte
input_group[0].new_empty(
args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
)
)

ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
else:
# CUDA: Allocate single args_tensor and saved_data tensor
ret.append(
# sizeof(int64_t) = 8, torch.uint8 = at::kByte
input_group[0].new_empty(
args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
)
)
)

ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))

return ret

Expand Down
6 changes: 5 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,11 @@ void group_index_select_or_add_cuda(
const int64_t total_num_warps,
const int group_size,
const bool use_index_select,
const bool use_var_cols);
const bool use_var_cols
#ifdef USE_ROCM
,const bool use_small_emb_dim
#endif
);

int get_group_index_select_cols_per_warp();

Expand Down
126 changes: 101 additions & 25 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed this offline. I believe that code duplication could be reduced by having a unified kernel template for both Nvidia and ROCm paths, and using std::visit with a platform-specific std::variant as in https://github.com/ROCm/FBGEMM/pull/139/changes#diff-6f509196a8893b5345f5e615251ce85ea5f575b81c1e9136fff764a899d92562R329-R407. This way we won't need to duplicate the invoke macro, and compilation time/binary size will be aligned with expectations for each platform. It will also make adding new template parameters easier.

Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,26 @@ int get_group_index_select_unroll_factor() {
return GROUP_INDEX_SELECT_UNROLL_FACTOR;
}

#ifdef USE_ROCM
template <
typename index_t,
typename scalar_t,
bool USE_INDEX_SELECT,
bool USE_VAR_COLS,
bool USE_SMALL_EMB_DIM,
int UNROLL_FACTOR,
int COLS_PER_WARP,
int LOG_COLS_PER_WARP>
#else
template <
typename index_t,
typename scalar_t,
bool USE_INDEX_SELECT,
bool USE_VAR_COLS,
int UNROLL_FACTOR,
int COLS_PER_WARP,
int LOG_COLS_PER_WARP>
#endif
__global__
__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t* input_ptrs,
Expand Down Expand Up @@ -84,39 +96,39 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
// All columns are the same
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to wrap it in USE_ROCM if we already have an if constexpr (USE_SMALL_EMB_DIM)?

if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Need to ensure that [member_id] and [member_warp_id] are calculated
// correctly for the small embedding dimension path below
if constexpr (USE_SMALL_EMB_DIM) {
// Small embedding: pack multiple rows per warp
const auto rows_per_warp = COLS_PER_WARP / num_cols;
const auto warps_per_member =
DIV_ROUND_UP(num_work_rows, rows_per_warp);
member_id = warp_id / warps_per_member;
member_warp_id = warp_id % warps_per_member;
} else {
#endif
// Large embedding: one or more warps per row
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM
}
#endif // USE_ROCM
#endif
}

#ifdef USE_ROCM

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Optimized path for small embedding dimensions
if constexpr (USE_SMALL_EMB_DIM) {
// Small embedding dimension: pack multiple rows per warp
// Each warp processes 'rows_per_warp' rows
const auto rows_per_warp = COLS_PER_WARP / num_cols;
const int64_t start_row = member_warp_id * rows_per_warp;
int64_t start_row = member_warp_id * rows_per_warp;

// Since we are processing multiple rows within the warp, we need to
// map each lane to a specific row, in addition to the column
const auto local_row = (threadIdx.x * UNROLL_FACTOR) /
num_cols; // the row ID within the set of rows handled by this warp
const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
const int64_t current_row = start_row +
local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if
// [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are
// within num_work_rows
int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane

// local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
// and we also need to confirm that we are within num_work_rows
Comment on lines +122 to +131

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we remove the const qualifier?

if (local_row < rows_per_warp && current_row < num_work_rows) {
scalar_t* input =
reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
Expand All @@ -129,18 +141,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
// Compile time conditional
if constexpr (USE_INDEX_SELECT) {
output[current_row * num_cols + i] =
LDG(&input[idx * num_cols + i]);
output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
} else {
gpuAtomicAddNoReturn(
&output[idx * num_cols + i], input[current_row * num_cols + i]);
}
}
}
} else {
// Large embedding dimensions use >= 1 warp per row
// which is the default codepath for non-ROCm as well
#endif // USE_ROCM
#endif
// Large embedding dimension: one or more warps per row
const auto row = member_warp_id / warps_per_row;
const auto col_offset =
((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
Expand All @@ -164,7 +174,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
#ifdef USE_ROCM
}
#endif // USE_ROCM
#endif
}
}

Expand All @@ -181,7 +191,11 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
const int64_t total_num_warps,
const int group_size,
const bool use_index_select,
const bool use_var_cols) {
const bool use_var_cols
#ifdef USE_ROCM
,const bool use_small_emb_dim

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could force use_small_emb_dim=false on the caller side in the CUDA case and keep the API consistent.

#endif
) {
if (group_size == 0) {
return;
}
Expand All @@ -197,6 +211,32 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
max_grid_size);
dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);

#ifdef USE_ROCM
// Kernel launch macro with USE_SMALL_EMB_DIM template parameter
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS, USE_SMALL_EMB_DIM) \
FBGEMM_LAUNCH_KERNEL( \
(group_index_select_or_add_2d_kernel< \
index_t, \
scalar_t, \
USE_INDEX_SELECT, \
USE_VAR_COLS, \
USE_SMALL_EMB_DIM, \
GROUP_INDEX_SELECT_UNROLL_FACTOR, \
GROUP_INDEX_SELECT_COLS_PER_WARP, \
GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
grid_size, \
block_size, \
0, \
at::cuda::getCurrentCUDAStream(), \
input_ptrs, \
output_ptrs, \
indices_ptrs, \
warp_offsets_group, \
num_cols_group, \
num_work_rows, \
group_size)
#else
// Kernel launch macro for CUDA (no USE_SMALL_EMB_DIM)
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
FBGEMM_LAUNCH_KERNEL( \
(group_index_select_or_add_2d_kernel< \
Expand All @@ -218,11 +258,46 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
num_cols_group, \
num_work_rows, \
group_size)
#endif

AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
FBGEMM_DISPATCH_FLOATING_TYPES(
input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
#ifdef USE_ROCM
if (use_small_emb_dim) {
// Small embedding dimension: pack multiple rows per warp
if (use_index_select) {
if (use_var_cols) {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, true);
} else {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, true);
}
} else {
if (use_var_cols) {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, true);
} else {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, true);
}
}
} else {
// Large embedding dimension: one or more warps per row
if (use_index_select) {
if (use_var_cols) {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, false);
} else {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, false);
}
} else {
if (use_var_cols) {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, false);
} else {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, false);
}
}
}
#else
// CUDA: Standard path only (no small embedding optimization)
if (use_index_select) {
if (use_var_cols) {
INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true);
Expand All @@ -236,6 +311,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false);
}
}
#endif
});
});

Expand Down
30 changes: 23 additions & 7 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3508,9 +3508,14 @@ torch::autograd::variable_list group_index_select_dim0(
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
.typed<decltype(group_index_select_dim0_autograd_impl)>();

auto res = forward_op.call(
all_indices_input_tensor, static_cast<int64_t>(group_size));
TORCH_CHECK(res.size() == group_size + 2);
all_indices_input_tensor, static_cast<int64_t>(group_size));
#ifdef USE_ROCM
TORCH_CHECK(res.size() >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have an exact comparison against the expected size, since >= might leave room for potential run-time errors in the future.

#else
TORCH_CHECK(res.size() == group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data)
#endif
// only return the outputs (the first group_size elements)
res.resize(group_size);
return res;
Expand Down Expand Up @@ -3621,7 +3626,12 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
.findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
.typed<decltype(group_index_select_dim0_forward_impl_cpu)>();
auto result = forward_op.call(all_indices_input, group_size);
TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
#ifdef USE_ROCM
TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#else
TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2); // CUDA: +2 tensors
#endif
ctx->saved_data["group_size"] = group_size;

auto [input_group, indices_group] =
group_index_select_dim0_unpack(all_indices_input, group_size);
Expand Down Expand Up @@ -3654,17 +3664,23 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output_group) {
TORCH_CHECK(grad_output_group.size() >= 2);
if (grad_output_group.size() == 2) {

auto group_size = ctx->saved_data["group_size"].toInt();
TORCH_CHECK(static_cast<int64_t>(grad_output_group.size()) >= group_size);

if (group_size == 0) {
// empty outputs
return torch::autograd::variable_list(1);
}
// remove redundant grads
auto group_size = grad_output_group.size() - 2;
grad_output_group.resize(group_size);

const auto saved_tensors = ctx->get_saved_variables();
TORCH_CHECK(saved_tensors.size() == group_size + 3);
#ifdef USE_ROCM
TORCH_CHECK(saved_tensors.size() >= group_size + 3); // ROCm: >= to handle both CPU (+3) and GPU (+5)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#else
TORCH_CHECK(saved_tensors.size() == group_size + 3); // CUDA: indices + 2 tensors + fwd_input
#endif
std::vector<c10::SymInt> output_shape_group;
int i = 0;
while (true) {
Expand Down
Loading