Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions cub/cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ struct AgentSelectIf
//---------------------------------------------------------------------
// Types and constants
//---------------------------------------------------------------------
using ScanTileStateT = ScanTileState<OffsetT>;
using ScanTileStateT = AtomicsBasedTileState<OffsetT>;

// Indicates whether the BlockLoad algorithm uses shared memory to load or exchange the data
static constexpr bool loads_via_smem =
Expand All @@ -222,7 +222,7 @@ struct AgentSelectIf

// If we need to enforce memory order for in-place stream compaction, wrap the default decoupled look-back tile
// state in a helper class that enforces memory order on reads and writes
using MemoryOrderedTileStateT = tile_state_with_memory_order<ScanTileStateT, memory_order>;
using MemoryOrderedTileStateT = ScanTileStateT; // tile_state_with_memory_order<ScanTileStateT, memory_order>;

// The input value type
using InputT = it_value_t<InputIteratorT>;
Expand Down Expand Up @@ -284,7 +284,7 @@ struct AgentSelectIf
// Callback type for obtaining tile prefix during block scan
using DelayConstructorT = typename AgentSelectIfPolicyT::detail::delay_constructor_t;
using TilePrefixCallbackOpT =
TilePrefixCallbackOp<OffsetT, ::cuda::std::plus<>, MemoryOrderedTileStateT, DelayConstructorT>;
AtomicsBasedTilePrefixCallbackOp<OffsetT, ::cuda::std::plus<>, MemoryOrderedTileStateT, DelayConstructorT>;

// Item exchange type
using ItemExchangeT = InputT[TILE_ITEMS];
Expand Down Expand Up @@ -400,7 +400,7 @@ struct AgentSelectIf
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
// Out-of-bounds items are selection_flags
selection_flags[ITEM] = 1;
selection_flags[ITEM] = false;

if (!IS_LAST_TILE || (static_cast<OffsetT>(threadIdx.x * ITEMS_PER_THREAD + ITEM) < num_tile_items))
{
Expand Down Expand Up @@ -429,7 +429,7 @@ struct AgentSelectIf
_CCCL_PRAGMA_UNROLL_FULL()
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
selection_flags[ITEM] = true;
selection_flags[ITEM] = false;
}
// Guarded loads
BlockLoadFlags(temp_storage.load_flags)
Expand Down Expand Up @@ -470,7 +470,7 @@ struct AgentSelectIf
{
// Out-of-bounds items are selection_flags
BlockLoadFlags(temp_storage.load_flags)
.Load((d_flags_in + streaming_context.input_offset()) + tile_offset, flags, num_tile_items, 1);
.Load((d_flags_in + streaming_context.input_offset()) + tile_offset, flags, num_tile_items, 0);
}
else
{
Expand All @@ -496,7 +496,7 @@ struct AgentSelectIf
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
constant_t<USE_DISCONTINUITY> /*select_method*/)
{
if (IS_FIRST_TILE && streaming_context.is_first_partition())
if ((tile_offset == 0 || IS_FIRST_TILE) && streaming_context.is_first_partition())
{
__syncthreads();

Expand Down Expand Up @@ -524,7 +524,7 @@ struct AgentSelectIf
// Set selection_flags for out-of-bounds items
if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items))
{
selection_flags[ITEM] = 1;
selection_flags[ITEM] = 0;
}
}
}
Expand Down Expand Up @@ -857,7 +857,7 @@ struct AgentSelectIf
0,
0,
num_tile_selections,
bool_constant_v < SelectionOpt == SelectImpl::Partition >);
bool_constant_v<SelectionOpt == SelectImpl::Partition>);

return num_tile_selections;
}
Expand Down Expand Up @@ -918,14 +918,6 @@ struct AgentSelectIf
OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix();
OffsetT num_rejected_prefix = tile_offset - num_selections_prefix;

// Discount any out-of-bounds selections
if (IS_LAST_TILE)
{
int num_discount = TILE_ITEMS - num_tile_items;
num_selections -= num_discount;
num_tile_selections -= num_discount;
}

// note (only applies to in-place stream compaction): We can avoid having to introduce explicit memory order between
// the look-back (i.e., loading previous tiles' states) and scattering items (which means, potentially overwriting
// previous tiles' input items, in case of in-place compaction), because this is implicitly ensured through
Expand All @@ -940,7 +932,7 @@ struct AgentSelectIf
num_selections_prefix,
num_rejected_prefix,
num_selections,
bool_constant_v < SelectionOpt == SelectImpl::Partition >);
bool_constant_v<SelectionOpt == SelectImpl::Partition>);

return num_selections;
}
Expand All @@ -966,14 +958,7 @@ struct AgentSelectIf
ConsumeTile(int num_tile_items, int tile_idx, OffsetT tile_offset, MemoryOrderedTileStateT& tile_state_wrapper)
{
OffsetT num_selections;
if (tile_idx == 0)
{
num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state_wrapper);
}
else
{
num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state_wrapper);
}
num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state_wrapper);

return num_selections;
}
Expand All @@ -998,14 +983,15 @@ struct AgentSelectIf
ConsumeRange(int num_tiles, ScanTileStateT& tile_state, NumSelectedIteratorT d_num_selected_out)
{
// Ensure consistent memory order across all tile status updates and loads
auto tile_state_wrapper = MemoryOrderedTileStateT{tile_state};
auto tile_state_wrapper = tile_state;

// Blocks are launched in increasing order, so just assign one tile per block
// TODO (elstehle): replacing this term with just `blockIdx.x` degrades perf for partition. Once we get to re-tune
// the algorithm, we want to replace this term with `blockIdx.x`
int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index
OffsetT tile_offset = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);

OffsetT num_selections;
if (tile_idx < num_tiles - 1)
{
// Not the last tile (full)
Expand All @@ -1014,13 +1000,17 @@ struct AgentSelectIf
else
{
// The last tile (possibly partially-full)
OffsetT num_remaining = num_items - tile_offset;
OffsetT num_selections = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state_wrapper);

if (threadIdx.x == 0)
OffsetT num_remaining = num_items - tile_offset;
num_selections = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state_wrapper);
}
if (threadIdx.x == 0)
{
auto tombstones = tile_state.note_tombstone();
if (tombstones == gridDim.x - 1)
{
// printf("Final tile: %lld\n", (long long)(tile_state.get_aggregate()));
// Update the number of selected items with this partition's selections
streaming_context.update_num_selected(d_num_selected_out, num_selections);
streaming_context.update_num_selected(d_num_selected_out, tile_state.get_aggregate());
}
}
}
Expand Down
200 changes: 194 additions & 6 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,198 @@ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t tile_state_init(
return AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes);
}

template <typename T>
struct AtomicsBasedTileState
{
  // Pair of device-resident counters shared by the whole grid:
  //  - d_atomic_offset: running sum of all tile aggregates contributed so far
  //  - d_atomic_tombstones: number of thread blocks that have finished their tile
  struct counters_t {
    T d_atomic_offset;
    uint32_t d_atomic_tombstones;
  };

  // Device storage (single counters_t living in the caller-provided temp allocation)
  counters_t* d_atomic_counter = nullptr;

  /**
   * @brief Initializer
   *
   * Binds this tile state to the caller-provided device allocation. Unlike the
   * decoupled look-back ScanTileState, no per-tile descriptor array is needed —
   * only a single counters_t.
   *
   * @param[in] num_tiles
   *   Number of tiles. Unused in this implementation.
   *
   * @param[in] d_temp_storage
   *   Device-accessible allocation of temporary storage; must hold at least
   *   sizeof(counters_t) bytes
   *
   * @param[in] temp_storage_bytes
   *   Size in bytes of \p d_temp_storage allocation
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t Init(int /*num_tiles*/, void* d_temp_storage, size_t temp_storage_bytes)
  {
    // Ensure temporary storage allocation is sufficient
    if(temp_storage_bytes < sizeof(counters_t))
    {
      return cudaErrorInvalidValue;
    }
    d_atomic_counter = reinterpret_cast<counters_t*>(d_temp_storage);

    return cudaSuccess;
  }

  /**
   * @brief Compute device memory needed for tile status (a single counters_t,
   * independent of the number of tiles)
   */
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE static constexpr cudaError_t
  AllocationSize(int /*num_tiles*/, size_t& temp_storage_bytes)
  {
    temp_storage_bytes = sizeof(counters_t);
    return cudaSuccess;
  }

  /**
   * Initialize (from device): exactly one thread of the init kernel zeroes both
   * counters. Must complete before any kernel that updates this tile state runs.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void InitializeStatus(int /*num_tiles*/)
  {
    int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (tile_idx < 1)
    {
      d_atomic_counter->d_atomic_offset = T{0};
      // NOTE(review): the tombstone counter is uint32_t but is initialized with
      // T{0}; harmless for integral T, but uint32_t{0} would state intent more clearly.
      d_atomic_counter->d_atomic_tombstones = T{0};
    }
  }

  /**
   * Fold a tile's aggregate into the global running total.
   * NOTE(review): despite the name (presumably kept for interface compatibility
   * with ScanTileState), this does not record a per-tile inclusive value — it
   * atomically adds \p tile_inclusive to the single shared offset counter, and
   * the old value is discarded.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE void SetInclusive(int /*tile_idx*/, T tile_inclusive)
  {
    auto x = atomicAdd(&d_atomic_counter->d_atomic_offset, tile_inclusive);
  }

  /**
   * Atomically add \p block_aggregate to the shared offset counter and return
   * the previous value — i.e. the calling tile's exclusive prefix in atomic
   * arrival order (not necessarily tile-index order).
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE auto atomic_add( T block_aggregate)
  {
    auto x = atomicAdd(&d_atomic_counter->d_atomic_offset, block_aggregate);
    return x;
  }

  /**
   * Atomically increment the tombstone counter and return its previous value.
   * Callers use this to detect the last thread block to finish (previous value
   * equals gridDim.x - 1).
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE auto note_tombstone()
  {
    auto x = atomicAdd(&d_atomic_counter->d_atomic_tombstones, 1);
    return x;
  }

  /**
   * Read the current value of the shared offset counter.
   * NOTE(review): this is a plain (non-atomic, unfenced) read; it only yields
   * the final aggregate if all other blocks' atomic updates are already visible
   * to this thread — confirm the tombstone handshake provides that ordering.
   */
  _CCCL_DEVICE _CCCL_FORCEINLINE auto get_aggregate()
  {
    return d_atomic_counter->d_atomic_offset;
  }
};


/**
 * Stateful block-scan prefix functor. Provides the running prefix for the
 * current tile by atomically folding the tile's aggregate into a single global
 * counter; the value returned by the atomic add is this tile's exclusive
 * prefix.
 *
 * NOTE(review): unlike TilePrefixCallbackOp, this implementation performs no
 * look-back or waiting on predecessor tiles, so prefixes are assigned in
 * atomic arrival order rather than tile-index order.
 *
 * @tparam DelayConstructorT
 *   Implementation detail, do not specify directly, requirements on the
 *   content of this type are subject to breaking change.
 */
template <typename T,
          typename ScanOpT,
          typename ScanTileStateT,
          typename DelayConstructorT = detail::default_delay_constructor_t<T>>
struct AtomicsBasedTilePrefixCallbackOp
{
  // Temporary storage type (stashed results so threads outside the calling
  // warp can read the prefixes after a block-wide barrier)
  struct _TempStorage
  {
    T exclusive_prefix;
    T inclusive_prefix;
    T block_aggregate;
  };

  // Alias wrapper allowing temporary storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  // Fields
  _TempStorage& temp_storage; ///< Reference to shared temporary storage
  ScanTileStateT& tile_status; ///< Interface to tile status (the global atomic counters)
  ScanOpT scan_op; ///< Binary scan operator. NOTE(review): unused below — the atomic
                   ///< accumulation assumes plus; confirm callers never pass another op.
  int tile_idx; ///< The current tile index
  T exclusive_prefix; ///< Exclusive prefix for the tile (valid warp-wide after operator())
  T inclusive_prefix; ///< Inclusive prefix for the tile (set on thread 0 only; others read temp_storage)

  // Constructs prefix functor for a given tile index.
  // Precondition: thread blocks processing all of the predecessor tiles were scheduled.
  _CCCL_DEVICE _CCCL_FORCEINLINE
  AtomicsBasedTilePrefixCallbackOp(ScanTileStateT& tile_status, TempStorage& temp_storage, ScanOpT scan_op, int tile_idx)
      : temp_storage(temp_storage.Alias())
      , tile_status(tile_status)
      , scan_op(scan_op)
      , tile_idx(tile_idx)
  {}

  // Computes the tile index and constructs prefix functor with it.
  // Precondition: thread block per tile assignment.
  _CCCL_DEVICE _CCCL_FORCEINLINE
  AtomicsBasedTilePrefixCallbackOp(ScanTileStateT& tile_status, TempStorage& temp_storage, ScanOpT scan_op)
      : AtomicsBasedTilePrefixCallbackOp(tile_status, temp_storage, scan_op, blockIdx.x)
  {}

  // BlockScan prefix callback functor (called by the first warp)
  _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T block_aggregate)
  {
    // Thread 0 atomically folds this tile's aggregate into the global counter;
    // the returned previous value is this tile's exclusive prefix.
    T thread_exclusive_prefix{};
    if (threadIdx.x == 0)
    {
      thread_exclusive_prefix = tile_status.atomic_add(block_aggregate);
      exclusive_prefix = thread_exclusive_prefix;
      // NOTE(review): inclusive prefix uses operator+ rather than scan_op —
      // consistent with the plus-only atomic accumulation above.
      inclusive_prefix = thread_exclusive_prefix + block_aggregate;
      // Publish to shared memory for threads outside the first warp
      // (presumably read via the Get*() accessors after a block-wide sync —
      // confirm against TilePrefixCallbackOp usage).
      temp_storage.block_aggregate = block_aggregate;
      temp_storage.exclusive_prefix = exclusive_prefix;
      temp_storage.inclusive_prefix = inclusive_prefix;
    }

    // Broadcast exclusive_prefix from lane 0 to the other threads of the calling warp
    exclusive_prefix = __shfl_sync(0xffffffff, exclusive_prefix, 0, 32);

    // Return exclusive_prefix
    return exclusive_prefix;
  }

  // Get the exclusive prefix stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetExclusivePrefix()
  {
    return temp_storage.exclusive_prefix;
  }

  // Get the inclusive prefix stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetInclusivePrefix()
  {
    return temp_storage.inclusive_prefix;
  }

  // Get the block aggregate stored in temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE T GetBlockAggregate()
  {
    return temp_storage.block_aggregate;
  }
};

} // namespace detail

/**
Expand Down Expand Up @@ -645,7 +837,6 @@ struct ScanTileState<T, true>
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage.
* When nullptr, the required allocation size is written to \p temp_storage_bytes and no work is
* done.
*
* @param[in] temp_storage_bytes
Expand Down Expand Up @@ -849,9 +1040,7 @@ struct ScanTileState<T, false>
* Number of tiles
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage.
* When nullptr, the required allocation size is written to \p temp_storage_bytes and no work is
* done.
* Device-accessible allocation of temporary storage. When nullptr, no work is done.
*
* @param[in] temp_storage_bytes
* Size in bytes of \t d_temp_storage allocation
Expand Down Expand Up @@ -1061,8 +1250,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
* Number of tiles
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage. When nullptr, the required allocation size
* is written to \p temp_storage_bytes and no work is done.
* Device-accessible allocation of temporary storage. When nullptr, no work is done.
*
* @param[in] temp_storage_bytes
* Size in bytes of \t d_temp_storage allocation
Expand Down
3 changes: 2 additions & 1 deletion cub/cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ struct DispatchSelectIf

using streaming_context_t = detail::select::streaming_context_t<num_total_items_t, may_require_streaming>;

using ScanTileStateT = ScanTileState<per_partition_offset_t>;
// using ScanTileStateT = ScanTileState<per_partition_offset_t>;
using ScanTileStateT = detail::AtomicsBasedTileState<per_partition_offset_t>;

static constexpr int INIT_KERNEL_THREADS = 128;

Expand Down
Loading
Loading