diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index 2de31ab7ed..b48a159963 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -982,14 +982,37 @@ def group_index_select_dim0_gpu_impl_abstract(
     # divide by 2 since sizeof(int64_t) / sizeof(int32_t) = 2
     args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2))

-    ret.append(
-        # sizeof(int64_t) = 8, torch.uint8 = at::kByte
-        input_group[0].new_empty(
-            args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
+    # Runtime check for ROCm vs CUDA
+    is_rocm = torch.version.hip is not None
+
+    if is_rocm:
+        # ROCm: Allocate dual args_tensors and saved_data tensors
+        ret.append(
+            # sizeof(int64_t) = 8, torch.uint8 = at::kByte
+            input_group[0].new_empty(
+                args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
+            )
+        )
+
+        ret.append(
+            # sizeof(int64_t) = 8, torch.uint8 = at::kByte
+            input_group[0].new_empty(
+                args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
+            )
+        )
+
+        ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
+        ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
+    else:
+        # CUDA: Allocate single args_tensor and saved_data tensor
+        ret.append(
+            # sizeof(int64_t) = 8, torch.uint8 = at::kByte
+            input_group[0].new_empty(
+                args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
+            )
         )
-    )
-    ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))
+        ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))

     return ret
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 0e7bd37234..09f5caa870 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1087,7 +1087,11 @@ void group_index_select_or_add_cuda(
     const int64_t total_num_warps,
     const int group_size,
     const bool use_index_select,
-    const bool use_var_cols);
+    const bool use_var_cols
+#ifdef USE_ROCM
+    ,const bool use_small_emb_dim
+#endif
+);

 int get_group_index_select_cols_per_warp();
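The meta implementation above has to advertise the same number of outputs as the real GPU forward further down: group_size results plus one args_tensor and one saved_data tensor on CUDA, or two of each on ROCm. A minimal Python sketch of that count (illustration only, not part of the patch; the helper name is hypothetical):

    import torch

    def expected_result_size(group_size: int) -> int:
        # CUDA: group_size outputs + 1 args_tensor + 1 saved_data tensor.
        # ROCm: group_size outputs + 2 args_tensors + 2 saved_data tensors
        #       (one pair for small-dim tables, one pair for large-dim tables).
        extra = 4 if torch.version.hip is not None else 2
        return group_size + extra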
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index fa59b1c091..4b3b1d0c3f 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -37,14 +37,26 @@ int get_group_index_select_unroll_factor() {
   return GROUP_INDEX_SELECT_UNROLL_FACTOR;
 }

+#ifdef USE_ROCM
 template <
     typename index_t,
     typename scalar_t,
     bool USE_INDEX_SELECT,
     bool USE_VAR_COLS,
+    bool USE_SMALL_EMB_DIM,
     int UNROLL_FACTOR,
     int COLS_PER_WARP,
     int LOG_COLS_PER_WARP>
+#else
+template <
+    typename index_t,
+    typename scalar_t,
+    bool USE_INDEX_SELECT,
+    bool USE_VAR_COLS,
+    int UNROLL_FACTOR,
+    int COLS_PER_WARP,
+    int LOG_COLS_PER_WARP>
+#endif
 __global__
 __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* input_ptrs,
@@ -84,39 +96,39 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     member_warp_id = warp_id - warp_offsets_group[member_id];
   } else {
     // All columns are the same
-    member_id = warp_id / (warps_per_row * num_work_rows);
-    member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
 #ifdef USE_ROCM
-    if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
-      // Need to ensure that [member_id] and [member_warp_id] are calculated
-      // correctly for the small embedding dimension path below
+    if constexpr (USE_SMALL_EMB_DIM) {
+      // Small embedding: pack multiple rows per warp
       const auto rows_per_warp = COLS_PER_WARP / num_cols;
       const auto warps_per_member = DIV_ROUND_UP(num_work_rows, rows_per_warp);
       member_id = warp_id / warps_per_member;
       member_warp_id = warp_id % warps_per_member;
+    } else {
+#endif
+      // Large embedding: one or more warps per row
+      member_id = warp_id / (warps_per_row * num_work_rows);
+      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+#ifdef USE_ROCM
     }
-#endif // USE_ROCM
+#endif
   }

 #ifdef USE_ROCM
-  if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
-    // Optimized path for small embedding dimensions
+  if constexpr (USE_SMALL_EMB_DIM) {
+    // Small embedding dimension: pack multiple rows per warp
     // Each warp processes 'rows_per_warp' rows
     const auto rows_per_warp = COLS_PER_WARP / num_cols;
-    const int64_t start_row = member_warp_id * rows_per_warp;
+    int64_t start_row = member_warp_id * rows_per_warp;
     // Since we are processing multiple rows within the warp, we need to
     // map each lane to a specific row, in addition to the column
-    const auto local_row = (threadIdx.x * UNROLL_FACTOR) /
-        num_cols; // the row ID within the set of rows handled by this warp
-    const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
-    const int64_t current_row = start_row +
-        local_row; // the actual row within the table processed by this lane
-
-    // local_row may be out of bounds for the last few lanes in the warp if
-    // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are
-    // within num_work_rows
+    int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
+    int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+    int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane
+
+    // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
+    // and we also need to confirm that we are within num_work_rows
     if (local_row < rows_per_warp && current_row < num_work_rows) {
       scalar_t* input =
           reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
@@ -129,8 +141,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
         // Compile time conditional
         if constexpr (USE_INDEX_SELECT) {
-          output[current_row * num_cols + i] =
-              LDG(&input[idx * num_cols + i]);
+          output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
         } else {
           gpuAtomicAddNoReturn(
               &output[idx * num_cols + i], input[current_row * num_cols + i]);
@@ -138,9 +149,8 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       }
     }
   } else {
-    // Large embedding dimensions use >= 1 warp per row
-    // which is the default codepath for non-ROCm as well
-#endif // USE_ROCM
+#endif
+    // Large embedding dimension: one or more warps per row
     const auto row = member_warp_id / warps_per_row;
     const auto col_offset =
         ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
@@ -164,7 +174,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     }
 #ifdef USE_ROCM
   }
-#endif // USE_ROCM
+#endif
   }
 }
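In the small-embedding path above, each lane covers UNROLL_FACTOR consecutive elements and is mapped to a (row, column) pair inside the warp. A Python sketch of that mapping with illustrative constants (the real COLS_PER_WARP, UNROLL_FACTOR, and warp width come from the build, not from these values):

    COLS_PER_WARP = 64   # illustrative, not the library constant
    UNROLL_FACTOR = 2    # illustrative
    WARP_SIZE = 32       # illustrative

    num_cols = 12        # a "small" embedding dimension: num_cols < COLS_PER_WARP
    rows_per_warp = COLS_PER_WARP // num_cols  # 5 rows packed into one warp here

    for lane in range(WARP_SIZE):
        local_row = (lane * UNROLL_FACTOR) // num_cols
        col_offset = (lane * UNROLL_FACTOR) % num_cols
        # Lanes whose local_row lands past rows_per_warp stay idle, which is what
        # the `local_row < rows_per_warp && current_row < num_work_rows` guard
        # enforces when COLS_PER_WARP % num_cols != 0.
        active = local_row < rows_per_warp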
@@ -181,7 +191,11 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
     const int64_t total_num_warps,
     const int group_size,
     const bool use_index_select,
-    const bool use_var_cols) {
+    const bool use_var_cols
+#ifdef USE_ROCM
+    ,const bool use_small_emb_dim
+#endif
+) {
   if (group_size == 0) {
     return;
   }
@@ -197,6 +211,32 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
       max_grid_size);
   dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);

+#ifdef USE_ROCM
+// Kernel launch macro with USE_SMALL_EMB_DIM template parameter
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS, USE_SMALL_EMB_DIM) \
+  FBGEMM_LAUNCH_KERNEL(                        \
+      (group_index_select_or_add_2d_kernel<    \
+          index_t,                             \
+          scalar_t,                            \
+          USE_INDEX_SELECT,                    \
+          USE_VAR_COLS,                        \
+          USE_SMALL_EMB_DIM,                   \
+          GROUP_INDEX_SELECT_UNROLL_FACTOR,    \
+          GROUP_INDEX_SELECT_COLS_PER_WARP,    \
+          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
+      grid_size,                               \
+      block_size,                              \
+      0,                                       \
+      at::cuda::getCurrentCUDAStream(),        \
+      input_ptrs,                              \
+      output_ptrs,                             \
+      indices_ptrs,                            \
+      warp_offsets_group,                      \
+      num_cols_group,                          \
+      num_work_rows,                           \
+      group_size)
+#else
+// Kernel launch macro for CUDA (no USE_SMALL_EMB_DIM)
 #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
   FBGEMM_LAUNCH_KERNEL(                                                  \
       (group_index_select_or_add_2d_kernel<                              \
@@ -218,11 +258,46 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
       num_cols_group, \
       num_work_rows,  \
       group_size)
+#endif

   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
             input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
+#ifdef USE_ROCM
+              if (use_small_emb_dim) {
+                // Small embedding dimension: pack multiple rows per warp
+                if (use_index_select) {
+                  if (use_var_cols) {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, true);
+                  } else {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, true);
+                  }
+                } else {
+                  if (use_var_cols) {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, true);
+                  } else {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, true);
+                  }
+                }
+              } else {
+                // Large embedding dimension: one or more warps per row
+                if (use_index_select) {
+                  if (use_var_cols) {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true, false);
+                  } else {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false, false);
+                  }
+                } else {
+                  if (use_var_cols) {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true, false);
+                  } else {
+                    INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false, false);
+                  }
+                }
+              }
+#else
+              // CUDA: Standard path only (no small embedding optimization)
               if (use_index_select) {
                 if (use_var_cols) {
                   INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true);
@@ -236,6 +311,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
                   INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false);
                 }
               }
+#endif
             });
       });
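The dispatch above selects between the small-dim and the standard warp layout. Roughly, the per-table warp budget implied by the kernel's indexing is the following (Python sketch; the COLS_PER_WARP value is illustrative, not the library constant):

    import math

    COLS_PER_WARP = 64  # illustrative

    def warps_for_table(num_rows: int, num_cols: int) -> int:
        if num_cols < COLS_PER_WARP:
            # Small-dim path: several rows share one warp, i.e.
            # DIV_ROUND_UP(num_work_rows, rows_per_warp) warps per table.
            rows_per_warp = COLS_PER_WARP // num_cols
            return math.ceil(num_rows / rows_per_warp)
        # Large-dim path: one or more warps per row.
        warps_per_row = math.ceil(num_cols / COLS_PER_WARP)
        return warps_per_row * num_rows

    # e.g. 1000 rows x 8 cols -> 125 warps instead of 1000 with a row-per-warp layout.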
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 28ba122208..cd0ef6dc7b 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -3508,9 +3508,14 @@ torch::autograd::variable_list group_index_select_dim0(
           at::Dispatcher::singleton()
               .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
               .typed();
+
   auto res = forward_op.call(
-      all_indices_input_tensor, static_cast<int64_t>(group_size));
-  TORCH_CHECK(res.size() == group_size + 2);
+      all_indices_input_tensor, static_cast<int64_t>(group_size));
+#ifdef USE_ROCM
+  TORCH_CHECK(res.size() >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)
+#else
+  TORCH_CHECK(res.size() == group_size + 2); // CUDA: +2 tensors (1 args, 1 saved_data)
+#endif
   // only return the outputs (the first group_size elements)
   res.resize(group_size);
   return res;
@@ -3621,7 +3626,12 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
           .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
           .typed();
   auto result = forward_op.call(all_indices_input, group_size);
-  TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
+#ifdef USE_ROCM
+  TORCH_CHECK(static_cast<int64_t>(result.size()) >= group_size + 2); // ROCm: >= to handle both CPU (+2) and GPU (+4)
+#else
+  TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2); // CUDA: +2 tensors
+#endif
+
   ctx->saved_data["group_size"] = group_size;
   auto [input_group, indices_group] =
       group_index_select_dim0_unpack(all_indices_input, group_size);
@@ -3654,17 +3664,23 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
     torch::autograd::AutogradContext* ctx,
     torch::autograd::variable_list grad_output_group) {
-  TORCH_CHECK(grad_output_group.size() >= 2);
-  if (grad_output_group.size() == 2) {
+
+  auto group_size = ctx->saved_data["group_size"].toInt();
+  TORCH_CHECK(static_cast<int64_t>(grad_output_group.size()) >= group_size);
+
+  if (group_size == 0) {
     // empty outputs
     return torch::autograd::variable_list(1);
   }
   // remove redundant grads
-  auto group_size = grad_output_group.size() - 2;
   grad_output_group.resize(group_size);

   const auto saved_tensors = ctx->get_saved_variables();
-  TORCH_CHECK(saved_tensors.size() == group_size + 3);
+#ifdef USE_ROCM
+  TORCH_CHECK(saved_tensors.size() >= group_size + 3); // ROCm: >= to handle both CPU (+3) and GPU (+5)
+#else
+  TORCH_CHECK(saved_tensors.size() == group_size + 3); // CUDA: indices + 2 tensors + fwd_input
+#endif
   std::vector<int64_t> output_shape_group;
   int i = 0;
   while (true) {
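The backward change above stops deriving the group size from the number of incoming gradients, since the trailing non-gradient tensors are no longer always two. A Python sketch of the equivalent trimming logic (illustration only):

    def trim_redundant_grads(grad_output_group, group_size):
        # Before: group_size = len(grad_output_group) - 2, which only holds when
        # exactly two extra tensors follow the real outputs (the CUDA layout).
        # Now group_size comes from the autograd context's saved_data, so the same
        # code accepts +2 (CUDA) or +4 (ROCm) trailing tensors.
        assert len(grad_output_group) >= group_size
        return grad_output_group[:group_size]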
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 4a6125b048..005b19fc0c 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -238,7 +238,53 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   // Total number of int64_t elements required
   args_ptrs_offsets[NUM_ARGS] = offset;

-  // Allocate memory for GroupIndexSelectArgs
+#ifdef USE_ROCM
+  // Allocate memory for GroupIndexSelectArgs (ROCm: split small/large)
+  at::Tensor args_tensor_small = at::empty(
+      {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
+      at::TensorOptions().dtype(at::kByte).pinned_memory(true));
+
+  at::Tensor args_tensor_large = at::empty(
+      {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
+      at::TensorOptions().dtype(at::kByte).pinned_memory(true));
+
+  // Ensure that args_tensors are contiguous
+  TORCH_CHECK(args_tensor_small.is_contiguous(), "Tensor args_tensor_small must be contiguous.");
+  TORCH_CHECK(args_tensor_large.is_contiguous(), "Tensor args_tensor_large must be contiguous.");
+
+  // Define a struct that maintains the arguments required by
+  // the GPU kernel
+  struct SplitArgs {
+    int64_t* input_ptrs = nullptr;
+    int64_t* output_ptrs = nullptr;
+    int64_t* indices_ptrs = nullptr;
+    int64_t* warp_offsets_group = nullptr;
+    int32_t* num_cols_group = nullptr;
+    int64_t total_warps = 0;
+    int64_t count = 0;
+  } small, large; // small and large structs to hold args for small and large
+                  // tables respectively
+
+  // Offset host pointers
+  offset_args(
+      &small.input_ptrs,
+      &small.output_ptrs,
+      &small.indices_ptrs,
+      &small.warp_offsets_group,
+      &small.num_cols_group,
+      reinterpret_cast<int64_t*>(args_tensor_small.mutable_data_ptr()),
+      args_ptrs_offsets);
+
+  offset_args(
+      &large.input_ptrs,
+      &large.output_ptrs,
+      &large.indices_ptrs,
+      &large.warp_offsets_group,
+      &large.num_cols_group,
+      reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
+      args_ptrs_offsets);
+#else
+  // Allocate memory for GroupIndexSelectArgs (CUDA: unified)
   at::Tensor args_tensor = at::empty(
       {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
       at::TensorOptions().dtype(at::kByte).pinned_memory(true));
@@ -263,6 +309,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       &num_cols_group,
       reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
       args_ptrs_offsets);
+#endif

   auto& first_input = input_group[0];
   auto& first_indices = indices_group[0];
@@ -273,10 +320,24 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
   const int num_cols = input_reshaped.size(1);
   const int cols_per_warp = get_group_index_select_cols_per_warp();
+  [[maybe_unused]] const int unroll_factor = get_group_index_select_unroll_factor();
+
+#ifdef USE_ROCM
+  int64_t warp_offset = 0;
+  bool use_var_cols_small = false;
+  bool use_var_cols_large = false;
+
+  bool first_small_table = true;
+  bool first_large_table = true;
+
+  int prev_num_cols_small;
+  int prev_num_cols_large;
+#else
   int64_t warp_offset = 0;
   bool use_var_cols = false;
+#endif

   // Allocate memory for output_group
   std::vector<Tensor> output_group;
@@ -340,13 +401,9 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       warps_needed = warps_per_row * num_output_rows_;
     }
 #else
-    // Standard: One or more warps per row
+    // CUDA: Standard path only (no small embedding optimization)
     auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
-#endif // USE_ROCM
-
-    if (num_cols != num_cols_) {
-      use_var_cols = true;
-    }
+#endif

     // Create output pointers
     auto input_shape = input.sizes().vec();
@@ -361,7 +418,45 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     input_contigs.push_back(input.expect_contiguous());
     index_contigs.push_back(indices.expect_contiguous());

-    // Store args
+#ifdef USE_ROCM
+    if (num_cols_ < cols_per_warp) {
+      // Optimization for small embedding: pack multiple rows per warp
+      if (!first_small_table && num_cols_ != prev_num_cols_small) {
+        use_var_cols_small = true;
+      }
+      first_small_table = false;
+      prev_num_cols_small = num_cols_;
+
+      small.input_ptrs[small.count] =
+          reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
+      small.output_ptrs[small.count] =
+          reinterpret_cast<int64_t>(output.mutable_data_ptr());
+      small.indices_ptrs[small.count] =
+          reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+      small.num_cols_group[small.count] = num_cols_;
+      small.warp_offsets_group[small.count] = small.total_warps;
+      small.total_warps += warps_needed;
+      small.count++;
+    } else {
+      // Standard embedding: one or more warps per row
+      if (!first_large_table && num_cols_ != prev_num_cols_large) {
+        use_var_cols_large = true;
+      }
+      first_large_table = false;
+      prev_num_cols_large = num_cols_;
+
+      large.input_ptrs[large.count] =
+          reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
+      large.output_ptrs[large.count] =
+          reinterpret_cast<int64_t>(output.mutable_data_ptr());
+      large.indices_ptrs[large.count] =
+          reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+      large.num_cols_group[large.count] = num_cols_;
+      large.warp_offsets_group[large.count] = large.total_warps;
+      large.total_warps += warps_needed;
+      large.count++;
+    }
+#else
+    // CUDA: Store args in unified arrays
+    if (num_cols != num_cols_) {
+      use_var_cols = true;
+    }
+
     input_ptrs[i] =
         reinterpret_cast<int64_t>(input_contigs[i]->const_data_ptr());
     output_ptrs[i] =
         reinterpret_cast<int64_t>(output.mutable_data_ptr());
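The per-table routing above amounts to the following bucketing (Python sketch, illustration only; `warps_needed` stands in for the warp count computed earlier in the loop):

    def split_tables(tables, cols_per_warp, warps_needed):
        # tables: list of (num_rows, num_cols) pairs, one per group member.
        buckets = {
            "small": {"cols": [], "offsets": [], "total_warps": 0},
            "large": {"cols": [], "offsets": [], "total_warps": 0},
        }
        for num_rows, num_cols in tables:
            b = buckets["small"] if num_cols < cols_per_warp else buckets["large"]
            b["offsets"].append(b["total_warps"])        # warp_offsets_group[count]
            b["total_warps"] += warps_needed(num_rows, num_cols)
            b["cols"].append(num_cols)
        for b in buckets.values():
            # use_var_cols is set once column counts differ within a bucket.
            b["use_var_cols"] = len(set(b["cols"])) > 1
        return buckets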
@@ -370,13 +465,52 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     warp_offsets_group[i] = warp_offset;
     num_cols_group[i] = num_cols_;
-#ifdef USE_ROCM
-    warp_offset += warps_needed;
-#else
     warp_offset += warps_per_row * num_output_rows;
-#endif // USE_ROCM
+#endif
   }

+#ifdef USE_ROCM
+  // Store the last offset
+  if (small.count > 0) {
+    small.warp_offsets_group[small.count] = small.total_warps;
+  }
+  if (large.count > 0) {
+    large.warp_offsets_group[large.count] = large.total_warps;
+  }
+
+  // Transfer args tensors to GPU
+  if (small.count > 0) {
+    args_tensor_small = args_tensor_small.to(
+        first_input.device(),
+        /*non_blocking=*/true);
+
+    // Offset raw ptrs in GPU memory
+    offset_args(
+        &small.input_ptrs,
+        &small.output_ptrs,
+        &small.indices_ptrs,
+        &small.warp_offsets_group,
+        &small.num_cols_group,
+        reinterpret_cast<int64_t*>(args_tensor_small.mutable_data_ptr()),
+        args_ptrs_offsets);
+  }
+
+  if (large.count > 0) {
+    args_tensor_large = args_tensor_large.to(
+        first_input.device(),
+        /*non_blocking=*/true);
+
+    // Offset raw ptrs in GPU memory
+    offset_args(
+        &large.input_ptrs,
+        &large.output_ptrs,
+        &large.indices_ptrs,
+        &large.warp_offsets_group,
+        &large.num_cols_group,
+        reinterpret_cast<int64_t*>(args_tensor_large.mutable_data_ptr()),
+        args_ptrs_offsets);
+  }
+#else
   // Store the last offset
   warp_offsets_group[group_size] = warp_offset;
@@ -394,7 +528,82 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       &num_cols_group,
       reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
       args_ptrs_offsets);
+#endif
+
+#ifdef USE_ROCM
+  int64_t saved_data_small[] = {
+      static_cast<int64_t>(small.count),
+      use_var_cols_small,
+      reinterpret_cast<int64_t>(small.warp_offsets_group),
+      reinterpret_cast<int64_t>(small.num_cols_group),
+      small.total_warps,
+  };
+  auto saved_data_t_small = at::empty(
+      {sizeof(saved_data_small) / sizeof(int64_t)},
+      at::TensorOptions().dtype(at::kLong));
+
+  TORCH_CHECK(saved_data_t_small.is_contiguous(), "Tensor saved_data_t_small must be contiguous.");
+  memcpy(saved_data_t_small.mutable_data_ptr(), saved_data_small, sizeof(saved_data_small));
+
+  int64_t saved_data_large[] = {
+      static_cast<int64_t>(large.count),
+      use_var_cols_large,
+      reinterpret_cast<int64_t>(large.warp_offsets_group),
+      reinterpret_cast<int64_t>(large.num_cols_group),
+      large.total_warps,
+  };
+  auto saved_data_t_large = at::empty(
+      {sizeof(saved_data_large) / sizeof(int64_t)},
+      at::TensorOptions().dtype(at::kLong));
+
+  TORCH_CHECK(saved_data_t_large.is_contiguous(), "Tensor saved_data_t_large must be contiguous.");
+  memcpy(saved_data_t_large.mutable_data_ptr(), saved_data_large, sizeof(saved_data_large));
+
+  if (small.count > 0) {
+    group_index_select_or_add_cuda(
+        small.input_ptrs,
+        small.output_ptrs,
+        small.indices_ptrs,
+        small.warp_offsets_group,
+        small.num_cols_group,
+        first_input.scalar_type(),
+        first_indices.scalar_type(),
+        first_input.device().index(),
+        num_output_rows,
+        /*total_num_warps=*/small.total_warps,
+        small.count,
+        /*use_index_select=*/true,
+        use_var_cols_small,
+        /*use_small_emb_dim=*/true);
+  }
+
+  if (large.count > 0) {
+    group_index_select_or_add_cuda(
+        large.input_ptrs,
+        large.output_ptrs,
+        large.indices_ptrs,
+        large.warp_offsets_group,
+        large.num_cols_group,
+        first_input.scalar_type(),
+        first_indices.scalar_type(),
+        first_input.device().index(),
+        num_output_rows,
+        /*total_num_warps=*/large.total_warps,
+        large.count,
+        /*use_index_select=*/true,
+        use_var_cols_large,
+        /*use_small_emb_dim=*/false);
+  }
+
+  output_group.push_back(args_tensor_small);
+  output_group.push_back(args_tensor_large);
+  output_group.push_back(saved_data_t_small);
+  output_group.push_back(saved_data_t_large);
+  // return format:
+  // (group_size outputs, 2 args_tensors, 2 saved_data)
+  return output_group;
+#else
   int64_t saved_data[] = {
       static_cast<int64_t>(group_size),
       use_var_cols,
@@ -431,11 +640,67 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   // return format:
   // (group_size outputs, 1 args_tensor, 1 saved_data)
   return output_group;
+#endif
 }

 static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     at::TensorList all_inputs,
     c10::SymIntArrayRef output_shape_group_ref) {
+
+#ifdef USE_ROCM
+  TORCH_CHECK_VALUE(
+      all_inputs.size() > 4,
+      "all_inputs size must be larger than 4, but got ",
+      all_inputs.size());
+
+  // all_inputs size = group_size * 2 (from grads, indices)
+  //                   + 2 args_tensors + 2 saved_data + 1 first input
+  const int64_t group_size = (all_inputs.size() - 5) / 2;
+
+  const Tensor& fwd_input = all_inputs[2 * group_size + 4];
+  const Tensor& saved_data_small = all_inputs[2 * group_size + 2];
+  const Tensor& saved_data_large = all_inputs[2 * group_size + 3];
+
+  const Tensor& first_indices = all_inputs[group_size];
+  const int64_t output_dim = fwd_input.dim();
+
+  auto grad_output_group = std::vector<Tensor>(
+      all_inputs.cbegin(), all_inputs.cbegin() + group_size);
+  std::vector<int64_t> output_shape_group;
+  output_shape_group.reserve(output_shape_group_ref.size());
+  for (const auto& i : output_shape_group_ref) {
+    output_shape_group.push_back(i.as_int_unchecked());
+  }
+
+  auto indices_group = std::vector<Tensor>(
+      all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);
+
+  // Retrieve saved data
+  TORCH_CHECK(saved_data_small.device() == at::kCPU, "Tensor saved_data_small must be on CPU.");
+  TORCH_CHECK(saved_data_small.is_contiguous(), "Tensor saved_data_small must be contiguous.");
+  const int64_t* saved_data_small_ptr = saved_data_small.const_data_ptr<int64_t>();
+  auto count_small = saved_data_small_ptr[0];
+  const bool use_var_cols_small = saved_data_small_ptr[1];
+  int64_t* warp_offsets_group_small = reinterpret_cast<int64_t*>(saved_data_small_ptr[2]);
+  int32_t* num_cols_group_small = reinterpret_cast<int32_t*>(saved_data_small_ptr[3]);
+  int64_t total_num_warps_small = saved_data_small_ptr[4];
+
+  TORCH_CHECK(saved_data_large.device() == at::kCPU, "Tensor saved_data_large must be on CPU.");
+  TORCH_CHECK(saved_data_large.is_contiguous(), "Tensor saved_data_large must be contiguous.");
+  const int64_t* saved_data_large_ptr = saved_data_large.const_data_ptr<int64_t>();
+  auto count_large = saved_data_large_ptr[0];
+  const bool use_var_cols_large = saved_data_large_ptr[1];
+  int64_t* warp_offsets_group_large = reinterpret_cast<int64_t*>(saved_data_large_ptr[2]);
+  int32_t* num_cols_group_large = reinterpret_cast<int32_t*>(saved_data_large_ptr[3]);
+  int64_t total_num_warps_large = saved_data_large_ptr[4];
+
+  TORCH_CHECK_VALUE(
+      (count_small + count_large) == group_size,
+      "The size of saved_data_small[0] + saved_data_large[0] must match group_size. Expect ",
+      group_size,
+      " but got ",
+      (count_small + count_large));
+#else
   TORCH_CHECK_VALUE(
       all_inputs.size() > 2,
       "all_inputs size must be larger than 2, but got ",
       all_inputs.size());
@@ -480,6 +745,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   const int32_t* num_cols_group =
       reinterpret_cast<int32_t*>(saved_data_ptr[3]);
   int64_t total_num_warps = saved_data_ptr[4];
+#endif

   // We checked in forward that all output rows are the same for all member
   // in the group
@@ -502,6 +768,27 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
   }

+#ifdef USE_ROCM
+  // Allocate Tensors for ptrs of grad output and input, and indices
+  Tensor args_tensor_small = at::empty(
+      {count_small * 3},
+      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
+
+  Tensor args_tensor_large = at::empty(
+      {count_large * 3},
+      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
+  // Ensure that args_tensors are contiguous
+  TORCH_CHECK(args_tensor_small.is_contiguous(), "Tensor args_tensor_small must be contiguous.");
+  TORCH_CHECK(args_tensor_large.is_contiguous(), "Tensor args_tensor_large must be contiguous.");
+
+  int64_t* grad_output_ptrs_small = args_tensor_small.mutable_data_ptr<int64_t>();
+  int64_t* grad_input_ptrs_small = args_tensor_small.mutable_data_ptr<int64_t>() + count_small;
+  int64_t* indices_ptrs_small = args_tensor_small.mutable_data_ptr<int64_t>() + 2 * count_small;
+
+  int64_t* grad_output_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>();
+  int64_t* grad_input_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>() + count_large;
+  int64_t* indices_ptrs_large = args_tensor_large.mutable_data_ptr<int64_t>() + 2 * count_large;
+#else
   // Allocate Tensor for ptrs of grad output and input, and indices
   Tensor args_tensor = at::empty(
       {group_size * 3},
       at::TensorOptions().dtype(at::kLong).pinned_memory(true));
@@ -514,6 +801,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       args_tensor.mutable_data_ptr<int64_t>() + group_size;
   int64_t* indices_ptrs =
       args_tensor.mutable_data_ptr<int64_t>() + 2 * group_size;
+#endif

   int64_t group_grad_input_numel = 0;
   std::vector<int64_t> grad_input_numels;
@@ -524,6 +812,12 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
   std::vector<c10::MaybeOwned<Tensor>> grad_output_contigs;
   grad_output_contigs.reserve(group_size);

+#ifdef USE_ROCM
+  const int cols_per_warp = get_group_index_select_cols_per_warp();
+  int64_t idx_small = 0;
+  int64_t idx_large = 0;
+#endif
+
   for (const auto i : c10::irange(group_size)) {
     const auto& grad = grad_output_group[i];
     TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices);
@@ -539,9 +833,25 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
     grad_input_numels.push_back(grad_input_numel);
     group_grad_input_numel += grad_input_numel;

+#ifdef USE_ROCM
+    const auto grad_reshaped = grad.reshape({grad.size(0), -1});
+    const auto num_cols_ = grad_reshaped.size(1);
+
+    // Put all grad output/input pointers in an array
+    if (num_cols_ < cols_per_warp) {
+      grad_output_ptrs_small[idx_small] =
+          reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
+      idx_small++;
+    } else {
+      grad_output_ptrs_large[idx_large] =
+          reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
+      idx_large++;
+    }
+#else
     // Put all grad output/input pointers in an array
     grad_output_ptrs[i] =
         reinterpret_cast<int64_t>(grad_output_contigs[i]->const_data_ptr());
+#endif
   }

   // Allocate a big tensor to avoid calling many small elementwise kernels
@@ -561,12 +871,19 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       " but got ",
       output_group.size());

+#ifdef USE_ROCM
+  // Reset the counters of the small and large arguments
+  idx_small = 0;
+  idx_large = 0;
+#endif
+
   // Reshape grad inputs and obtain their pointers
   for (int i = 0; i < group_size; i++) {
     const auto grad_input_shape = std::vector<int64_t>(
         output_shape_group.begin() + i * output_dim,
         output_shape_group.begin() + (i + 1) * output_dim);
     output_group[i] = output_group[i].reshape(grad_input_shape);
+
     TORCH_CHECK(
         output_group[i].is_contiguous(),
         "Tensor output_group ",
@@ -574,23 +891,96 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
         " of ",
         group_size,
         " must be contiguous.");
+
+#ifdef USE_ROCM
+    const auto output_group_reshaped =
+        grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    const auto num_cols_ = output_group_reshaped.size(1);
+    if (num_cols_ < cols_per_warp) {
+      grad_input_ptrs_small[idx_small] =
+          reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
+      idx_small++;
+    } else {
+      grad_input_ptrs_large[idx_large] =
+          reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
+      idx_large++;
+    }
+#else
     grad_input_ptrs[i] =
         reinterpret_cast<int64_t>(output_group[i].const_data_ptr());
+#endif

     // 2) Add group_size gradients for inputs
     outputs.push_back(output_group[i]);
   }

+#ifdef USE_ROCM
+  // Reset the counters of the small and large arguments
+  idx_small = 0;
+  idx_large = 0;
+#endif
+
   // Calculate indices_ptrs
   std::vector<c10::MaybeOwned<Tensor>> index_contigs;
   index_contigs.reserve(group_size);
   for (const auto i : c10::irange(group_size)) {
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
+
+#ifdef USE_ROCM
+    const auto grad_output_reshaped =
+        grad_output_group[i].reshape({grad_output_group[i].size(0), -1});
+    const auto num_cols_ = grad_output_reshaped.size(1);
+    if (num_cols_ < cols_per_warp) {
+      indices_ptrs_small[idx_small] =
+          reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+      idx_small++;
+    } else {
+      indices_ptrs_large[idx_large] =
+          reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+      idx_large++;
+    }
+#else
     indices_ptrs[i] =
         reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+#endif
   }

+#ifdef USE_ROCM
+  // Transfer grad output pointers to GPU
+  args_tensor_small = args_tensor_small.to(first_indices.device(), /*non_blocking=*/true);
+  args_tensor_large = args_tensor_large.to(first_indices.device(), /*non_blocking=*/true);
+
+  if (count_small > 0) {
+    group_index_select_or_add_cuda(
+        args_tensor_small.const_data_ptr<int64_t>(),
+        args_tensor_small.const_data_ptr<int64_t>() + count_small,
+        args_tensor_small.const_data_ptr<int64_t>() + 2 * count_small,
+        warp_offsets_group_small,
+        num_cols_group_small,
+        fwd_input.scalar_type(),
+        first_indices.scalar_type(),
+        fwd_input.device().index(),
+        num_input_rows,
+        total_num_warps_small,
+        count_small,
+        /*use_index_select=*/false,
+        use_var_cols_small,
+        /*use_small_emb_dim=*/true);
+  }
+
+  if (count_large > 0) {
+    group_index_select_or_add_cuda(
+        args_tensor_large.data_ptr<int64_t>(),
+        args_tensor_large.data_ptr<int64_t>() + count_large,
+        args_tensor_large.data_ptr<int64_t>() + 2 * count_large,
+        warp_offsets_group_large,
+        num_cols_group_large,
+        fwd_input.scalar_type(),
+        first_indices.scalar_type(),
+        fwd_input.device().index(),
+        num_input_rows,
+        total_num_warps_large,
+        count_large,
+        /*use_index_select=*/false,
+        use_var_cols_large,
+        /*use_small_emb_dim=*/false);
+  }
+#else
   // Transfer grad output pointers to GPU
   args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);

@@ -608,6 +998,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       group_size,
       /*use_index_select=*/false,
       use_var_cols);
+#endif

   return outputs;
 }
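In the backward pass each bucket packs its three pointer arrays into a single pinned int64 tensor, so one host-to-device copy moves them all; the kernel then receives the base pointer plus offsets of count and 2 * count. A small sketch of that layout (illustration only):

    def args_tensor_slices(count):
        # Layout of args_tensor_small / args_tensor_large ({count * 3} int64 slots):
        grad_output_ptrs = slice(0, count)
        grad_input_ptrs = slice(count, 2 * count)
        indices_ptrs = slice(2 * count, 3 * count)
        return grad_output_ptrs, grad_input_ptrs, indices_ptrs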