Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
4915005
Add kernelPath and includeDir to the ctor
KaoCC Sep 7, 2023
b725f75
Add missing header
KaoCC Sep 8, 2023
117a318
fix template constexpr rule
KaoCC Sep 8, 2023
571111b
Use default values for bitcode
KaoCC Sep 8, 2023
c9bb91b
[ORO-0] simple porting
AtsushiYoshimura0302 Sep 14, 2023
12fd48c
[ORO-0] fixed storage ver
AtsushiYoshimura0302 Sep 21, 2023
284b083
[ORO-0] Memset can be skipped in most situations
AtsushiYoshimura0302 Sep 21, 2023
d12fdea
[ORO-0] fix the wrong condition for MAX_LOOK_BACK
AtsushiYoshimura0302 Sep 21, 2023
8be62ff
Merge remote-tracking branch 'origin/main' into feature/ORO-0-radixso…
AtsushiYoshimura0302 Sep 22, 2023
688be9a
[ORO-0]fix wrong address for counter - it was luckly working
AtsushiYoshimura0302 Sep 22, 2023
22e1790
Merge remote-tracking branch 'origin/main' into feature/ORO-0-radixso…
AtsushiYoshimura0302 Sep 26, 2023
861039e
Merge remote-tracking branch 'origin/main' into feature/ORO-0-radixso…
AtsushiYoshimura0302 Dec 25, 2023
a360186
use resizeAsync
AtsushiYoshimura0302 Dec 25, 2023
1bb319a
remove inl dependency
AtsushiYoshimura0302 Dec 25, 2023
cc8ad0e
constexpr noexcept for helper funcs
AtsushiYoshimura0302 Dec 25, 2023
e9b33f8
Use GPU timer
AtsushiYoshimura0302 Dec 25, 2023
8681ee0
other test variants
AtsushiYoshimura0302 Dec 25, 2023
f6de36d
Split iterators
AtsushiYoshimura0302 Dec 25, 2023
ed8fb9d
Split temp buffer for simplicity
AtsushiYoshimura0302 Dec 25, 2023
d6c37b6
to constexprs
AtsushiYoshimura0302 Dec 26, 2023
56fa76d
Fix smaller n execution.
AtsushiYoshimura0302 Dec 26, 2023
0ba36d5
adaptive blocksize for counting
AtsushiYoshimura0302 Dec 26, 2023
d704f74
use const ref
AtsushiYoshimura0302 Dec 26, 2023
11671a4
remove define
AtsushiYoshimura0302 Dec 26, 2023
99cc300
fix compile error and remove unused comments
AtsushiYoshimura0302 Dec 26, 2023
1fd425c
remove macro
AtsushiYoshimura0302 Dec 26, 2023
c0ee24b
unified types
AtsushiYoshimura0302 Dec 26, 2023
8dbf83a
to constexpr noexcept
AtsushiYoshimura0302 Dec 26, 2023
be5f26e
use constexpr and remove unused functions
AtsushiYoshimura0302 Dec 26, 2023
a443768
use BIN_SIZE constant
AtsushiYoshimura0302 Dec 26, 2023
82d4aad
extract common process as extractDigit()
AtsushiYoshimura0302 Dec 26, 2023
866b70c
keyPair as a template parameter
AtsushiYoshimura0302 Dec 26, 2023
a9c4e61
remove unused codes
AtsushiYoshimura0302 Dec 26, 2023
7ae1709
delete unused inl
AtsushiYoshimura0302 Dec 26, 2023
6c9fc49
Add a special case handling, all elements have the same digit, to red…
AtsushiYoshimura0302 Dec 26, 2023
35b2654
Refactor indices
AtsushiYoshimura0302 Dec 26, 2023
7daad5c
implement counting part
AtsushiYoshimura0302 Dec 28, 2023
5043a03
slow but works
AtsushiYoshimura0302 Dec 28, 2023
2655782
Simplify
AtsushiYoshimura0302 Dec 28, 2023
d33d590
shared approach
AtsushiYoshimura0302 Dec 28, 2023
35a02f9
add explicit sync
AtsushiYoshimura0302 Dec 29, 2023
2179be6
larger block
AtsushiYoshimura0302 Dec 29, 2023
1269018
key cache
AtsushiYoshimura0302 Dec 29, 2023
bcb56c9
16bit lpsum
AtsushiYoshimura0302 Dec 29, 2023
1207fdc
16bit blockHist
AtsushiYoshimura0302 Dec 29, 2023
b78f5dc
keyValue support
AtsushiYoshimura0302 Dec 31, 2023
fd9357f
smaller warpOffsets
AtsushiYoshimura0302 Dec 31, 2023
1224e6d
n batch loading
AtsushiYoshimura0302 Dec 31, 2023
2c99707
warp level is fine
AtsushiYoshimura0302 Dec 31, 2023
59afd88
clean up
AtsushiYoshimura0302 Jan 2, 2024
fedd3c5
psum in gHistogram
AtsushiYoshimura0302 Jan 2, 2024
19dd9fa
remove unused
AtsushiYoshimura0302 Jan 2, 2024
ac0605d
refactor
AtsushiYoshimura0302 Jan 2, 2024
34ac365
fix undefined behavior and simplify
AtsushiYoshimura0302 Jan 3, 2024
3289c1b
simplify
AtsushiYoshimura0302 Jan 3, 2024
680a910
refactor
AtsushiYoshimura0302 Jan 3, 2024
d7d0274
support non blockDim != 256 case
AtsushiYoshimura0302 Jan 3, 2024
dea1411
remove unused
AtsushiYoshimura0302 Jan 3, 2024
c6f871b
reduce loops and ealier tail iterator is better
AtsushiYoshimura0302 Jan 3, 2024
402db80
remove redundant sync
AtsushiYoshimura0302 Jan 3, 2024
16e046c
use constant decl
AtsushiYoshimura0302 Jan 3, 2024
69b09b1
shorten
AtsushiYoshimura0302 Jan 14, 2024
2261c02
remove unused
AtsushiYoshimura0302 Jan 14, 2024
6723b8e
remove too much optimizations, fix potential sync issue etc
AtsushiYoshimura0302 Jan 18, 2024
1d399ef
remove unused branching. Thanks to ChihChen
AtsushiYoshimura0302 Feb 23, 2024
20b8e55
refactor the tail iterator conditions
AtsushiYoshimura0302 Feb 23, 2024
a1731fa
simple code is just fine at gHistogram. No more KEY_IS_16BYTE_ALIGNED
AtsushiYoshimura0302 Feb 23, 2024
d5beef7
remove unused
AtsushiYoshimura0302 Feb 24, 2024
1ff67d9
unify atomicInc
AtsushiYoshimura0302 Feb 24, 2024
b3af1e9
remove temporal splitmix64
AtsushiYoshimura0302 Mar 1, 2024
656e578
use arg_cast instead
AtsushiYoshimura0302 Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 84 additions & 144 deletions ParallelPrimitives/RadixSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <dlfcn.h>
#endif

constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
constexpr uint64_t next_multiple64( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64( val, divisor ) * divisor; }

namespace
{
#if defined( ORO_PRECOMPILED )
Expand Down Expand Up @@ -76,23 +79,6 @@ RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils, oroStream stream,
configure( kernelPath, includeDir, stream );
}

void RadixSort::exclusiveScanCpu( const Oro::GpuMemory<int>& countsGpu, Oro::GpuMemory<int>& offsetsGpu ) const noexcept
{
const auto buffer_size = countsGpu.size();

std::vector<int> counts = countsGpu.getData();
std::vector<int> offsets( buffer_size );

int sum = 0;
for( int i = 0; i < counts.size(); ++i )
{
offsets[i] = sum;
sum += counts[i];
}

offsetsGpu.copyFromHost( offsets.data(), std::size( offsets ) );
}

void RadixSort::compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept
{
static constexpr auto defaultKernelPath{ "../ParallelPrimitives/RadixSortKernels.h" };
Expand Down Expand Up @@ -124,203 +110,157 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
binaryPath = getCurrentDir();
binaryPath += isAmd ? "oro_compiled_kernels.hipfb" : "oro_compiled_kernels.fatbin";
log = "loading pre-compiled kernels at path : " + binaryPath;

m_num_threads_per_block_for_count = DEFAULT_COUNT_BLOCK_SIZE;
m_num_threads_per_block_for_scan = DEFAULT_SCAN_BLOCK_SIZE;
m_num_threads_per_block_for_sort = DEFAULT_SORT_BLOCK_SIZE;

m_warp_size = DEFAULT_WARP_SIZE;
}
else
{
log = "compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir;

m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE;
m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE;
m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE;

m_warp_size = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;

assert( m_num_threads_per_block_for_count % m_warp_size == 0 );
assert( m_num_threads_per_block_for_scan % m_warp_size == 0 );
assert( m_num_threads_per_block_for_sort % m_warp_size == 0 );
}

m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / m_warp_size;

if( m_flags == Flag::LOG )
{
std::cout << log << std::endl;
}

const auto includeArg{ "-I" + currentIncludeDir };
std::vector<const char*> opts;
opts.push_back( includeArg.c_str() );

struct Record
{
std::string kernelName;
Kernel kernelType;
};

const std::vector<Record> records{
{ "CountKernel", Kernel::COUNT }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL }, { "SortKernel", Kernel::SORT },
{ "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
{ "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
};

const auto includeArg{ "-I" + currentIncludeDir };
const auto overwrite_flag = "-DOVERWRITE";
const auto count_block_size_param = "-DCOUNT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_count );
const auto scan_block_size_param = "-DSCAN_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_scan );
const auto sort_block_size_param = "-DSORT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_sort );
const auto sort_num_warps_param = "-DSORT_NUM_WARPS_PER_BLOCK_VAL=" + std::to_string( m_num_warps_per_block_for_sort );

std::vector<const char*> opts;

if( const std::string device_name = m_props.name; device_name.find( "NVIDIA" ) != std::string::npos )
{
opts.push_back( "--use_fast_math" );
}
else
{
opts.push_back( "-ffast-math" );
}

opts.push_back( includeArg.c_str() );
opts.push_back( overwrite_flag );
opts.push_back( count_block_size_param.c_str() );
opts.push_back( scan_block_size_param.c_str() );
opts.push_back( sort_block_size_param.c_str() );
opts.push_back( sort_num_warps_param.c_str() );

for( const auto& record : records )
{
if constexpr( useBakeKernel )
{
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
}
else if constexpr( useBitCode )
#if defined( ORO_PP_LOAD_FROM_STRING )
oroFunctions[record.kernelType] = oroutils.getFunctionFromString( device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
#else

if constexpr( useBitCode )
{
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() );
}
else
{
const auto includeArg{ "-I" + currentIncludeDir };
std::vector<const char*> opts;
opts.push_back( includeArg.c_str() );
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts );
}

#endif
if( m_flags == Flag::LOG )
{
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
}
}
}

int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept
{
const int warpPerWG = blockSize / m_warp_size;
const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / m_warp_size;
const int occupancyFromWarp = ( warpPerWGP > 0 ) ? ( warpPerWGP / warpPerWG ) : 1;

const int occupancy = std::max( 1, occupancyFromWarp );

if( m_flags == Flag::LOG )
{
std::cout << "Occupancy: " << occupancy << '\n';
}

static constexpr auto min_num_blocks = 16;
auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks;

if( m_num_threads_per_block_for_scan > BIN_SIZE )
{
// Note: both are divisible by 2
const auto base = m_num_threads_per_block_for_scan / BIN_SIZE;

// Floor
number_of_blocks = ( number_of_blocks / base ) * base;
}
// TODO: bit code support?
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what to do about this ... should we disable bitcode for now ?

#define LOAD_FUNC( var, kernel ) var = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), kernel, &opts );
LOAD_FUNC( m_gHistogram, "gHistogram" );
LOAD_FUNC( m_onesweep_reorderKey64, "onesweep_reorderKey64" );
LOAD_FUNC( m_onesweep_reorderKeyPair64, "onesweep_reorderKeyPair64" );
#undef LOAD_FUNC

return number_of_blocks;
}

void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
{
compileKernels( kernelPath, includeDir );

m_num_blocks_for_count = calculateWGsToExecute( m_num_threads_per_block_for_count );
u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * sizeof( u32 /* key type */ );
m_gpSumBuffer.resizeAsync( gpSumBuffer, false /*copy*/, stream );

/// The tmp buffer size of the count kernel and the scan kernel.
u64 lookBackBuffer = sizeof( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE );
m_lookbackBuffer.resizeAsync( lookBackBuffer, false /*copy*/, stream );

const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count;

/// @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan
/// This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly

m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan;

m_tmp_buffer.resizeAsync( tmp_buffer_size, false, stream );

if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
{
// These are for the scan kernel
m_partial_sum.resizeAsync( m_num_blocks_for_scan, false, stream );
m_is_ready.resizeAsync( m_num_blocks_for_scan, false, stream );
m_is_ready.resetAsync( stream );
}
m_tailIterator.resizeAsync( 1, false /*copy*/, stream );
m_tailIterator.resetAsync( stream );
m_gpSumCounter.resizeAsync( 1, false /*copy*/, stream );
}
void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; }

void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream ) noexcept
void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept
{
bool keyPair = src.value != nullptr;

// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
// right now, setting this as large as possible is faster than multi pass sorting
if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
{
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit };
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
if( keyPair )
{
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit };
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
}
else
{
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
const void* args[] = { &src, &dst, &n, &startBit, &endBit };
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
}
return;
}

auto* s{ &src };
auto* d{ &dst };
int nIteration = div_round_up64( endBit - startBit, 8 );
uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE );

for( int i = startBit; i < endBit; i += N_RADIX )
{
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );

std::swap( s, d );
}
m_lookbackBuffer.resetAsync( stream );
m_gpSumCounter.resetAsync( stream );
m_gpSumBuffer.resetAsync( stream );

if( s == &src )
// counter for gHistogram.
{
OrochiUtils::copyDtoDAsync( dst.key, src.key, n, stream );
OrochiUtils::copyDtoDAsync( dst.value, src.value, n, stream );
}
}
int maxBlocksPerMP = 0;
oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, m_gHistogram, GHISTOGRAM_THREADS_PER_BLOCK, 0 );
const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048;

void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int endBit, oroStream stream ) noexcept
{
// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
// right now, setting this as large as possible is faster than multi pass sorting
if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
{
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
const void* args[] = { &src, &dst, &n, &startBit, &endBit };
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
return;
const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) };
OrochiUtils::launch1D( m_gHistogram, nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream );
}

auto* s{ &src };
auto* d{ &dst };

for( int i = startBit; i < endBit; i += N_RADIX )
auto s = src;
auto d = dst;
for( int i = 0; i < nIteration; i++ )
{
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );
if( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 )
{
m_lookbackBuffer.resetAsync( stream );
} // other wise, we can skip zero clear look back buffer

if( keyPair )
{
const void* args[] = { &s.key, &d.key, &s.value, &d.value, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
OrochiUtils::launch1D( m_onesweep_reorderKeyPair64, numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
}
else
{
const void* args[] = { &s.key, &d.key, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
OrochiUtils::launch1D( m_onesweep_reorderKey64, numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
}
std::swap( s, d );
}

if( s == &src )
if( s.key == src.key )
{
OrochiUtils::copyDtoDAsync( dst, src, n, stream );
oroMemcpyDtoDAsync( (oroDeviceptr)dst.key, (oroDeviceptr)src.key, sizeof( uint32_t ) * n, stream );

if( keyPair )
{
oroMemcpyDtoDAsync( (oroDeviceptr)dst.value, (oroDeviceptr)src.value, sizeof( uint32_t ) * n, stream );
}
}
}

void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept
{
sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream );
}
}; // namespace Oro
Loading