Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ad1c1df
Base changes in scan and tests.
griwes Feb 8, 2026
6371339
Update benchmarks.
griwes Feb 8, 2026
9e346b5
Update copyright years.
griwes Feb 8, 2026
910d511
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Feb 8, 2026
e9467af
c.parallel: centralize the handling of common cub types.
griwes Feb 8, 2026
c4c0c09
Resolve review comments.
griwes Feb 13, 2026
2c2db7c
Fix c.parallel radix_sort breakage.
griwes Feb 13, 2026
a288da0
integrate warpspeed: Merge remote-tracking branch 'origin/main' into …
griwes Feb 18, 2026
9eec3e2
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Feb 18, 2026
2a0ddf4
Compilation fixes.
griwes Feb 19, 2026
22ece56
Go through dispatch_arch, unify dispatch paths for scan.
griwes Feb 19, 2026
d7f5333
Remove cuda::std::optional from policies.
griwes Feb 19, 2026
497638c
Pull scan_warpspeed_policy out into its own file.
griwes Feb 19, 2026
61db5eb
Check for is_constant_evaluated in new dispatch.
griwes Feb 19, 2026
5f3aeda
Fix some thinkos.
griwes Feb 20, 2026
72239e2
Compilation fixes.
griwes Feb 20, 2026
861b25c
Remove %RANGE% declarations from header
bernhardmgruber Feb 25, 2026
0a7d3e7
Add delay_constructor_policy to look_back_helper.cuh
bernhardmgruber Feb 25, 2026
2a7e044
CI fixes.
griwes Feb 26, 2026
1353f06
Refactor DeviceScan, remove warpspeed from policy_hub, address review…
griwes Feb 26, 2026
863c874
Operators for scan_warpspeed_policy.
griwes Feb 26, 2026
6867500
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Feb 26, 2026
699d4bf
Fix clang build.
griwes Feb 26, 2026
ab85bfe
Fix a missed test.
griwes Feb 26, 2026
372bcb9
Fix clang-cuda concept checks.
griwes Feb 26, 2026
27ab1c4
Fix classifications of bool and min/max.
griwes Feb 26, 2026
16bc452
Add missing includes.
griwes Feb 27, 2026
aec8147
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Mar 5, 2026
565a017
Codegen fixes.
griwes Mar 13, 2026
9cd3ca0
Review comments.
griwes Mar 13, 2026
f6af88d
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Mar 13, 2026
81b7a7f
More abstraction layers to restore constexprness.
griwes Mar 13, 2026
029b195
Correctly check for the constants.
griwes Mar 13, 2026
ac03691
Another abstraction layer, to remove a constexpr reference to `this`.
griwes Mar 13, 2026
a6ed3cd
I kinda hate this but I think it has to be like this.
griwes Mar 13, 2026
3bb8169
Silence a warning.
griwes Mar 13, 2026
5dbcfd6
Silence MSVC unreachable code warning.
griwes Mar 13, 2026
203e807
Merge remote-tracking branch 'origin/main' into feature/new-tuning-ap…
griwes Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c/parallel/include/cccl/c/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ typedef struct cccl_device_scan_build_result_t
CUkernel scan_kernel;
bool force_inclusive;
cccl_init_kind_t init_kind;
bool use_warpspeed;
size_t description_bytes_per_tile;
size_t payload_bytes_per_tile;
void* runtime_policy;
Expand Down
17 changes: 4 additions & 13 deletions c/parallel/src/radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -223,18 +223,7 @@ try
std::string offset_t;
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

// TODO(bgruber): generalize this somewhere
const auto key_type = [&] {
switch (input_keys_it.value_type.type)
{
case CCCL_FLOAT32:
return cub::detail::type_t::float32;
case CCCL_FLOAT64:
return cub::detail::type_t::float64;
default:
return cub::detail::type_t::other;
}
}();
const auto key_type = cccl_type_enum_to_cub_type(input_keys_it.value_type.type);

const auto policy_sel = cub::detail::radix_sort::policy_selector{
static_cast<int>(input_keys_it.value_type.size),
Expand Down Expand Up @@ -268,6 +257,8 @@ using device_radix_sort_policy = {5};
using namespace cub;
using namespace cub::detail;
using namespace cub::detail::radix_sort;
using cub::detail::delay_constructor_policy;
using cub::detail::delay_constructor_kind;
static_assert(device_radix_sort_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {6}, "Host generated and JIT compiled policy mismatch");
)XXX",
input_keys_it.value_type.size, // 0
Expand Down
29 changes: 3 additions & 26 deletions c/parallel/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -223,31 +223,8 @@ try
const auto policy_sel = [&] {
using namespace cub::detail;

auto accum_type = type_t::other;
if (accum_t.type == CCCL_FLOAT32)
{
accum_type = type_t::float32;
}
else if (accum_t.type == CCCL_FLOAT64)
{
accum_type = type_t::float64;
}

auto operation_t = op_kind_t::other;
switch (op.type)
{
case CCCL_PLUS:
operation_t = op_kind_t::plus;
break;
case CCCL_MINIMUM:
operation_t = op_kind_t::min;
break;
case CCCL_MAXIMUM:
operation_t = op_kind_t::max;
break;
default:
break;
}
const auto accum_type = cccl_type_enum_to_cub_type(accum_t.type);
const auto operation_t = cccl_op_kind_to_cub_op(op.type);

const int offset_size = int{sizeof(OffsetT)};
return cub::detail::reduce::policy_selector{accum_type, operation_t, offset_size, static_cast<int>(accum_t.size)};
Expand Down
195 changes: 106 additions & 89 deletions c/parallel/src/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/launcher/cuda_driver.cuh>
#include <cub/detail/ptx-json-parser.cuh>
#include <cub/device/dispatch/dispatch_scan.cuh>
#include <cub/thread/thread_load.cuh>
#include <cub/util_arch.cuh>
Expand All @@ -21,6 +20,7 @@
#include <format>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
Expand Down Expand Up @@ -56,33 +56,6 @@ enum class InitKind
NoInit,
};

// Runtime wrapper that adapts a JSON-derived RuntimeScanAgentPolicy to the
// compile-time policy interface expected by cub::DispatchScan (it is passed as
// the policy-hub template argument to DispatchScan further down in this file).
// NOTE(review): this diff removes the struct — kept here only as the pre-change
// state of the hunk; confirm against the full file.
struct scan_runtime_tuning_policy
{
  // The single agent policy driving the scan kernel, parsed at build time
  // from the ptx-json blob embedded in the compiled cubin.
  cub::detail::RuntimeScanAgentPolicy scan;

  // Accessor mirroring the static policy hub's Scan() member.
  auto Scan() const
  {
    return scan;
  }

  // Reject LOAD_LDG: texture-path loads bypass the memory consistency model,
  // which is unsound for the decoupled look-back scan's tile-state traffic.
  void CheckLoadModifier() const
  {
    if (scan.LoadModifier() == cub::CacheLoadModifier::LOAD_LDG)
    {
      throw std::runtime_error("The memory consistency model does not apply to texture "
                               "accesses");
    }
  }

  // Self-referential MaxPolicy emulates the ChainedPolicy terminator so the
  // dispatch machinery treats this runtime policy as its own chain head.
  using MaxPolicy = scan_runtime_tuning_policy;

  // ChainedPolicy-style entry point: ignores the PTX version (already baked
  // into `scan`) and invokes the functor with this policy directly.
  template <typename F>
  cudaError_t Invoke(int, F& op)
  {
    return op.template Invoke<scan_runtime_tuning_policy>(*this);
  }
};

static cccl_type_info get_accumulator_type(cccl_op_t /*op*/, cccl_iterator_t /*input_it*/, cccl_type_info init)
{
// TODO Should be decltype(op(init, *input_it)) but haven't implemented type arithmetic yet
Expand Down Expand Up @@ -135,8 +108,8 @@ std::string get_scan_kernel_name(
bool force_inclusive,
cccl_init_kind_t init_kind)
{
std::string chained_policy_t;
check(cccl_type_name_from_nvrtc<device_scan_policy>(&chained_policy_t));
std::string policy_selector_t;
check(cccl_type_name_from_nvrtc<device_scan_policy>(&policy_selector_t));

const cccl_type_info accum_t = scan::get_accumulator_type(op, input_it, init);
const std::string accum_cpp_t = cccl_type_enum_to_name(accum_t.type);
Expand Down Expand Up @@ -177,7 +150,7 @@ std::string get_scan_kernel_name(
auto tile_state_t = std::format("cub::ScanTileState<{0}>", accum_cpp_t);
return std::format(
"cub::detail::scan::DeviceScanKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}>",
chained_policy_t, // 0
policy_selector_t, // 0
input_iterator_t, // 1
output_iterator_t, // 2
tile_state_t, // 3
Expand All @@ -189,20 +162,6 @@ std::string get_scan_kernel_name(
init_t); // 9
}

// Policy shim parameterized on a policy-lookup function pointer: at dispatch
// time it resolves the concrete scan_runtime_tuning_policy for the device's
// PTX version and the stored accumulator type, then forwards the invocation.
// NOTE(review): this diff removes the struct — kept here only as the pre-change
// state of the hunk; confirm against the full file.
template <auto* GetPolicy>
struct dynamic_scan_policy_t
{
  // Acts as its own ChainedPolicy terminator (see MaxPolicy convention in
  // cub dispatch machinery).
  using MaxPolicy = dynamic_scan_policy_t;

  // Resolve the runtime policy via GetPolicy(ptx_version, accum_type) and
  // invoke the dispatch functor with it.
  template <typename F>
  cudaError_t Invoke(int device_ptx_version, F& op)
  {
    return op.template Invoke<scan_runtime_tuning_policy>(GetPolicy(device_ptx_version, accumulator_type));
  }

  // Accumulator type descriptor used to select among per-type tunings.
  cccl_type_info accumulator_type;
};

struct scan_kernel_source
{
cccl_device_scan_build_result_t& build;
Expand All @@ -219,11 +178,16 @@ struct scan_kernel_source
{
return build.scan_kernel;
}
scan_tile_state TileState()
scan_tile_state TileState() const
{
return {build.description_bytes_per_tile, build.payload_bytes_per_tile};
}

bool use_warpspeed(const cub::detail::scan::scan_policy& /*policy*/) const
{
return build.use_warpspeed;
}

std::size_t look_ahead_tile_state_size() const
{
return look_ahead_tile_state_alignment();
Expand Down Expand Up @@ -287,8 +251,77 @@ try

const auto output_it_value_t = cccl_type_enum_to_name(output_it.value_type.type);

std::string policy_hub_expr = std::format(
"cub::detail::scan::policy_hub<{}, {}, {}, {}, {}>",
const auto policy_sel = [&] {
using cub::detail::scan::policy_selector;
using cub::detail::scan::primitive_accum;
using cub::detail::scan::primitive_op;

const auto accum_type = cccl_type_enum_to_cub_type(accum_t.type);
const auto operation_t = cccl_op_kind_to_cub_op(op.type);
const auto input_type = input_it.value_type.type;
const auto input_type_t = cccl_type_enum_to_cub_type(input_type);

const auto output_type = output_it.value_type.type;
const bool types_match = input_type == output_type && input_type == accum_t.type;
const bool benchmark_match =
operation_t != cub::detail::op_kind_t::other && types_match && input_type != CCCL_STORAGE;
const bool accum_is_primitive_or_trivially_copy_constructible = true;

return policy_selector{
static_cast<int>(input_it.value_type.size),
static_cast<int>(input_it.value_type.alignment),
static_cast<int>(output_it.value_type.size),
static_cast<int>(output_it.value_type.alignment),
static_cast<int>(accum_t.size),
static_cast<int>(accum_t.alignment),
int{sizeof(OffsetT)},
input_type_t,
accum_type,
operation_t,
accum_is_primitive_or_trivially_copy_constructible,
benchmark_match};
}();

const auto arch_id = cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor});
const auto active_policy = policy_sel(arch_id);

#if _CCCL_CUDACC_AT_LEAST(12, 8)
const auto is_trivial_type = [](cccl_type_enum /* type */) {
// TODO: implement actual logic here when nontrivial custom types become supported
return true;
};

const bool input_contiguous = input_it.type == cccl_iterator_kind_t::CCCL_POINTER;
const bool output_contiguous = output_it.type == cccl_iterator_kind_t::CCCL_POINTER;
const bool input_trivially_copyable = is_trivial_type(input_it.value_type.type);
const bool output_trivially_copyable = is_trivial_type(output_it.value_type.type);
const bool output_default_constructible = output_trivially_copyable;

const bool use_warpspeed =
active_policy.warpspeed
&& cub::detail::scan::use_warpspeed(
active_policy.warpspeed,
static_cast<int>(input_it.value_type.size),
static_cast<int>(input_it.value_type.alignment),
static_cast<int>(output_it.value_type.size),
static_cast<int>(output_it.value_type.alignment),
static_cast<int>(accum_t.size),
static_cast<int>(accum_t.alignment),
input_contiguous,
output_contiguous,
input_trivially_copyable,
output_trivially_copyable,
output_default_constructible);
#else
const bool use_warpspeed = false;
#endif

// TODO(bgruber): drop this if tuning policies become formattable
std::stringstream policy_sel_str;
policy_sel_str << active_policy;

std::string policy_selector_expr = std::format(
"cub::detail::scan::policy_selector_from_types<{}, {}, {}, {}, {}>",
input_it_value_t,
output_it_value_t,
accum_cpp,
Expand All @@ -307,20 +340,20 @@ struct __align__({1}) storage_t {{
{2}
{3}
{4}
using device_scan_policy = {5}::MaxPolicy;

#include <cub/detail/ptx-json/json.cuh>
__device__ consteval auto& policy_generator() {{
return ptx_json::id<ptx_json::string("device_scan_policy")>()
= cub::detail::scan::ScanPolicyWrapper<device_scan_policy::ActivePolicy>::EncodedPolicy();
}}
using device_scan_policy = {5};
using namespace cub;
using namespace cub::detail::scan;
using cub::detail::delay_constructor_policy;
using cub::detail::delay_constructor_kind;
static_assert(device_scan_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {6}, "Host generated and JIT compiled policy mismatch");
)XXX",
input_it.value_type.size, // 0
input_it.value_type.alignment, // 1
input_iterator_src, // 2
output_iterator_src, // 3
op_src, // 4
policy_hub_expr); // 5
policy_selector_expr, // 5
policy_sel_str.view()); // 6

#if false // CCCL_DEBUGGING_SWITCH
fflush(stderr);
Expand All @@ -344,7 +377,6 @@ __device__ consteval auto& policy_generator() {{
"-rdc=true",
"-dlto",
"-DCUB_DISABLE_CDP",
"-DCUB_ENABLE_POLICY_PTX_JSON",
"-std=c++20"};

cccl::detail::extend_args_with_build_config(args, config);
Expand Down Expand Up @@ -379,11 +411,6 @@ __device__ consteval auto& policy_generator() {{
auto [description_bytes_per_tile,
payload_bytes_per_tile] = get_tile_state_bytes_per_tile(accum_t, accum_cpp, args.data(), args.size(), arch);

nlohmann::json runtime_policy = cub::detail::ptx_json::parse("device_scan_policy", {result.data.get(), result.size});

using cub::detail::RuntimeScanAgentPolicy;
auto scan_policy = RuntimeScanAgentPolicy::from_json(runtime_policy, "ScanPolicyT");

build_ptr->cc = cc;
build_ptr->cubin = (void*) result.data.release();
build_ptr->cubin_size = result.size;
Expand All @@ -392,7 +419,8 @@ __device__ consteval auto& policy_generator() {{
build_ptr->init_kind = init_kind;
build_ptr->description_bytes_per_tile = description_bytes_per_tile;
build_ptr->payload_bytes_per_tile = payload_bytes_per_tile;
build_ptr->runtime_policy = new scan::scan_runtime_tuning_policy{scan_policy};
build_ptr->runtime_policy = new cub::detail::scan::policy_selector{policy_sel};
build_ptr->use_warpspeed = use_warpspeed;

return CUDA_SUCCESS;
}
Expand Down Expand Up @@ -426,30 +454,18 @@ CUresult cccl_device_scan(
CUdevice cu_device;
check(cuCtxGetDevice(&cu_device));

auto exec_status = cub::DispatchScan<
indirect_arg_t,
indirect_arg_t,
indirect_arg_t,
std::conditional_t<std::is_same_v<InitValueT, cub::NullType>, cub::NullType, indirect_arg_t>,
cuda::std::size_t,
void,
EnforceInclusive,
scan::scan_runtime_tuning_policy,
scan::scan_kernel_source,
cub::detail::CudaDriverLauncherFactory>::
Dispatch(
d_temp_storage,
*temp_storage_bytes,
d_in,
d_out,
op,
init,
num_items,
stream,
{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
*reinterpret_cast<scan::scan_runtime_tuning_policy*>(build.runtime_policy));

auto exec_status = cub::detail::scan::dispatch_with_accum<void, EnforceInclusive>(
d_temp_storage,
*temp_storage_bytes,
indirect_arg_t{d_in},
indirect_arg_t{d_out},
indirect_arg_t{op},
std::conditional_t<std::is_same_v<InitValueT, cub::NullType>, cub::NullType, indirect_arg_t>{init},
static_cast<OffsetT>(num_items),
stream,
*static_cast<cub::detail::scan::policy_selector*>(build.runtime_policy),
scan::scan_kernel_source{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});
error = static_cast<CUresult>(exec_status);
}
catch (const std::exception& exc)
Expand Down Expand Up @@ -591,7 +607,8 @@ try
return CUDA_ERROR_INVALID_VALUE;
}
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(build_ptr->cubin));
std::unique_ptr<char[]> policy(reinterpret_cast<char*>(build_ptr->runtime_policy));
std::unique_ptr<cub::detail::scan::policy_selector> policy(
static_cast<cub::detail::scan::policy_selector*>(build_ptr->runtime_policy));
check(cuLibraryUnload(build_ptr->library));

return CUDA_SUCCESS;
Expand Down
Loading
Loading