Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions c/parallel/test/test_segmented_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cstdlib>
#include <optional> // std::optional
#include <string>
#include <type_traits>
#include <vector>

#include <cuda_runtime.h>
Expand Down Expand Up @@ -72,8 +73,14 @@ auto& get_cache()
return fixture<segmented_sort_build_cache_t, Tag>::get_or_create().get_value();
}

template <bool DisableSassCheckOnSm120 = false>
struct segmented_sort_build
{
// Decides whether the SASS check should run for this build
// (NOTE(review): inferred from the name; the caller is not visible here).
//
// cc_major: presumably the compute-capability major version of the target
//           device -- confirm against the caller.
//
// Returns false only when the DisableSassCheckOnSm120 template flag is set
// AND cc_major is 12 or greater; returns true in every other case.
//
// Marked constexpr because the result depends solely on a compile-time
// template constant and the integer argument, so it is usable in constant
// expressions as well as at run time.
static constexpr bool should_check_sass(int cc_major)
{
  return !(DisableSassCheckOnSm120 && cc_major >= 12);
}

CUresult operator()(
BuildResultT* build_ptr,
cccl_sort_order_t sort_order,
Expand Down Expand Up @@ -144,7 +151,9 @@ struct segmented_sort_run
}
};

template <typename BuildCache = segmented_sort_build_cache_t, typename KeyT = std::string>
template <bool DisableSassCheckOnSm120 = false,
typename BuildCache = segmented_sort_build_cache_t,
typename KeyT = std::string>
void segmented_sort(
cccl_sort_order_t sort_order,
cccl_iterator_t keys_in,
Expand All @@ -160,7 +169,12 @@ void segmented_sort(
std::optional<BuildCache>& cache,
const std::optional<KeyT>& lookup_key)
{
AlgorithmExecute<BuildResultT, segmented_sort_build, segmented_sort_cleanup, segmented_sort_run, BuildCache, KeyT>(
AlgorithmExecute<BuildResultT,
segmented_sort_build<DisableSassCheckOnSm120>,
segmented_sort_cleanup,
segmented_sort_run,
BuildCache,
KeyT>(
cache,
lookup_key,
sort_order,
Expand All @@ -186,10 +200,11 @@ C2H_TEST("segmented_sort can sort keys-only", "[segmented_sort][keys_only]", tes
using T = c2h::get<0, TestType>;
using key_t = typename T::KeyT;

constexpr auto this_test_params = T();
constexpr bool is_descending = this_test_params.is_descending();
constexpr auto order = is_descending ? CCCL_DESCENDING : CCCL_ASCENDING;
constexpr bool is_overwrite_okay = this_test_params.is_overwrite_okay();
constexpr auto this_test_params = T();
constexpr bool is_descending = this_test_params.is_descending();
constexpr auto order = is_descending ? CCCL_DESCENDING : CCCL_ASCENDING;
constexpr bool is_overwrite_okay = this_test_params.is_overwrite_okay();
constexpr bool disable_sass_check_on_sm120 = std::is_same_v<key_t, c2h::get<3, key_types>>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I think I just hit this locally today :D


const std::size_t n_segments = GENERATE(0, 13, take(2, random(1 << 10, 1 << 12)));
const std::size_t segment_size = GENERATE(1, 12, take(2, random(1 << 10, 1 << 12)));
Expand Down Expand Up @@ -272,7 +287,7 @@ C2H_TEST("segmented_sort can sort keys-only", "[segmented_sort][keys_only]", tes

int selector = -1;

segmented_sort(
segmented_sort<disable_sass_check_on_sm120>(
order,
keys_in_ptr,
keys_out_ptr,
Expand Down Expand Up @@ -315,10 +330,11 @@ C2H_TEST("segmented_sort can sort key-value pairs", "[segmented_sort][key_value]
using T = c2h::get<0, TestType>;
using key_t = typename T::KeyT;

constexpr auto this_test_params = T();
constexpr bool is_descending = this_test_params.is_descending();
constexpr auto order = is_descending ? CCCL_DESCENDING : CCCL_ASCENDING;
constexpr bool is_overwrite_okay = this_test_params.is_overwrite_okay();
constexpr auto this_test_params = T();
constexpr bool is_descending = this_test_params.is_descending();
constexpr auto order = is_descending ? CCCL_DESCENDING : CCCL_ASCENDING;
constexpr bool is_overwrite_okay = this_test_params.is_overwrite_okay();
constexpr bool disable_sass_check_on_sm120 = !std::is_same_v<key_t, c2h::get<0, key_types>>;

const std::size_t n_segments = GENERATE(0, 13, take(2, random(1 << 10, 1 << 12)));
const std::size_t segment_size = GENERATE(1, 12, take(2, random(1 << 10, 1 << 12)));
Expand Down Expand Up @@ -371,7 +387,7 @@ C2H_TEST("segmented_sort can sort key-value pairs", "[segmented_sort][key_value]

int selector = -1;

segmented_sort(
segmented_sort<disable_sass_check_on_sm120>(
order,
keys_in_ptr,
keys_out_ptr,
Expand Down Expand Up @@ -583,6 +599,8 @@ C2H_TEST("SegmentedSort works with variable segment sizes", "[segmented_sort][va
constexpr bool is_descending = this_test_params.is_descending();
constexpr auto order = is_descending ? CCCL_DESCENDING : CCCL_ASCENDING;
constexpr bool is_overwrite_okay = this_test_params.is_overwrite_okay();
constexpr bool disable_sass_check_on_sm120 =
std::is_same_v<key_t, c2h::get<1, key_types>> || std::is_same_v<key_t, c2h::get<2, key_types>>;

const std::size_t n_segments = GENERATE(20, 600);

Expand Down Expand Up @@ -644,7 +662,7 @@ C2H_TEST("SegmentedSort works with variable segment sizes", "[segmented_sort][va

int selector = -1;

segmented_sort(
segmented_sort<disable_sass_check_on_sm120>(
order,
keys_in_ptr,
keys_out_ptr,
Expand Down
30 changes: 19 additions & 11 deletions c/parallel/test/test_three_way_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ auto& get_cache()
return fixture<three_way_partition_build_cache_t, Tag>::get_or_create().get_value();
}

template <bool DisableSassCheck = false>
template <bool DisableSassCheck = false, bool DisableSassCheckOnSm120 = false>
struct three_way_partition_build
{
template <typename... Rest>
Expand Down Expand Up @@ -77,9 +77,9 @@ struct three_way_partition_build
rest...);
}

static constexpr bool should_check_sass(int)
static constexpr bool should_check_sass(int cc_major)
{
return !DisableSassCheck;
return !DisableSassCheck && !(DisableSassCheckOnSm120 && cc_major >= 12);
}
};

Expand Down Expand Up @@ -200,7 +200,12 @@ std_partition(FirstPartSelectionOp first_selector, SecondPartSelectionOp second_
return result;
}

template <typename OperationT, typename KeyT, typename NumSelectedT, typename TagT, bool DisableSassCheck = false>
template <typename OperationT,
typename KeyT,
typename NumSelectedT,
typename TagT,
bool DisableSassCheck = false,
bool DisableSassCheckOnSm120 = false>
three_way_partition_result_t<KeyT>
c_parallel_partition(OperationT first_selector, OperationT second_selector, const std::vector<KeyT>& input)
{
Expand All @@ -215,7 +220,7 @@ c_parallel_partition(OperationT first_selector, OperationT second_selector, cons
auto& build_cache = get_cache<TagT>();
const auto& test_key = make_key<KeyT, NumSelectedT>();

three_way_partition<DisableSassCheck>(
three_way_partition<DisableSassCheck, DisableSassCheckOnSm120>(
input_ptr,
first_part_output_ptr,
second_part_output_ptr,
Expand All @@ -241,9 +246,10 @@ c_parallel_partition(OperationT first_selector, OperationT second_selector, cons
num_items - num_selected[0] - num_selected[1]);
}

template <bool DisableSassCheck = false,
typename BuildCache = three_way_partition_build_cache_t,
typename KeyT = std::string>
template <bool DisableSassCheck = false,
bool DisableSassCheckOnSm120 = false,
typename BuildCache = three_way_partition_build_cache_t,
typename KeyT = std::string>
void three_way_partition(
cccl_iterator_t d_in,
cccl_iterator_t d_first_part_out,
Expand All @@ -257,7 +263,7 @@ void three_way_partition(
const std::optional<KeyT>& lookup_key)
{
AlgorithmExecute<BuildResultT,
three_way_partition_build<DisableSassCheck>,
three_way_partition_build<DisableSassCheck, DisableSassCheckOnSm120>,
three_way_partition_cleanup,
three_way_partition_run,
BuildCache,
Expand Down Expand Up @@ -362,7 +368,9 @@ extern "C" __device__ void greater_or_equal_op(void* state_ptr, void* x_ptr, voi
c_parallel_partition<stateful_operation_t<selector_state_t>,
key_t,
num_selected_t,
ThreeWayPartition_StatefulOperations_Fixture_Tag>(less_op, greater_or_equal_op, input);
ThreeWayPartition_StatefulOperations_Fixture_Tag,
false,
true>(less_op, greater_or_equal_op, input);
auto std_result = std_partition(less_than_t<key_t>{key_t{21}}, greater_or_equal_t<key_t>{key_t{21}}, input);

REQUIRE(c_parallel_result == std_result);
Expand Down Expand Up @@ -491,7 +499,7 @@ C2H_TEST("ThreeWayPartition works with iterators", "[three_way_partition]")
auto& build_cache = get_cache<ThreeWayPartition_Iterators_Fixture_Tag>();
const auto& test_key = make_key<key_t, num_selected_t>();

three_way_partition(
three_way_partition<false, true>(
input_it,
first_part_output_it,
second_part_output_it,
Expand Down
6 changes: 6 additions & 0 deletions ci/matrix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ workflows:
# c.parallel -- pinned to gcc13 on Linux to match python
- {jobs: ['test'], project: 'cccl_c_parallel', ctk: '12.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080']}
- {jobs: ['test'], project: 'cccl_c_parallel', ctk: '13.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080', 'l4', 'h100']}
# RTX PRO 6000 coverage (limited due to small number of runners):
- {jobs: ['test'], project: 'cccl_c_parallel', ctk: '13.X', cxx: ['gcc13'], gpu: ['rtxpro6000']}
# c.experimental.stf -- pinned to gcc13 to match python
- {jobs: ['test'], project: 'cccl_c_stf', ctk: '12.X', cxx: 'gcc13', gpu: ['rtx2080']}
- {jobs: ['test'], project: 'cccl_c_stf', ctk: '13.X', cxx: 'gcc13', gpu: ['rtx2080', 'l4', 'h100']}
Expand Down Expand Up @@ -182,6 +184,8 @@ workflows:
# c.parallel -- pinned to gcc13 to match python
- {jobs: ['test'], project: ['cccl_c_parallel'], ctk: '12.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080']}
- {jobs: ['test'], project: ['cccl_c_parallel'], ctk: '13.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080', 'l4', 'h100']}
# RTX PRO 6000 coverage (limited due to small number of runners):
- {jobs: ['test'], project: 'cccl_c_parallel', ctk: '13.X', cxx: ['gcc13'], gpu: ['rtxpro6000']}
# c.experimental.stf -- pinned to gcc13 to match python
- {jobs: ['test'], project: ['cccl_c_stf'], ctk: '12.X', cxx: 'gcc13', gpu: ['rtx2080']}
- {jobs: ['test'], project: ['cccl_c_stf'], ctk: '13.X', cxx: 'gcc13', gpu: ['rtx2080', 'l4', 'h100']}
Expand Down Expand Up @@ -267,6 +271,8 @@ workflows:
# c.parallel -- pinned to gcc13 to match python
- {jobs: ['test'], project: ['cccl_c_parallel'], ctk: '12.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080']}
- {jobs: ['test'], project: ['cccl_c_parallel'], ctk: '13.X', cxx: ['gcc13', 'msvc'], gpu: ['rtx2080', 'l4', 'h100']}
# RTX PRO 6000 coverage (limited due to small number of runners):
- {jobs: ['test'], project: 'cccl_c_parallel', ctk: '13.X', cxx: ['gcc13'], gpu: ['rtxpro6000']}
# c.experimental.stf -- pinned to gcc13 to match python
- {jobs: ['test'], project: ['cccl_c_stf'], ctk: '12.X', cxx: 'gcc13', gpu: ['rtx2080']}
- {jobs: ['test'], project: ['cccl_c_stf'], ctk: '13.X', cxx: 'gcc13', gpu: ['rtx2080', 'l4', 'h100']}
Expand Down
Loading