258 changes: 93 additions & 165 deletions ParallelPrimitives/RadixSort.cpp
@@ -42,47 +42,50 @@ namespace hip
{
static const char** RadixSortKernelsArgs = nullptr;
static const char** RadixSortKernelsIncludes = nullptr;
} // namespace hip
#endif

#if defined( __GNUC__ )
#include <dlfcn.h>
#endif

#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
#include <ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
#else
const unsigned char oro_compiled_kernels_h[] = "";
const size_t oro_compiled_kernels_h_size = 0;
#endif

constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
constexpr uint64_t next_multiple64( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64( val, divisor ) * divisor; }
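
As a quick illustration of the rounding helpers (values are illustrative; the asserts hold by the definitions above):

static_assert( div_round_up64( 10, 4 ) == 3 );   // ceil( 10 / 4 )
static_assert( next_multiple64( 10, 4 ) == 12 ); // smallest multiple of 4 that is >= 10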

namespace
{

// if these two preprocessor definitions are enabled, this activates the 'usePrecompiledAndBakedKernel' mode.
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )

// this flag means that we bake the precompiled kernels
constexpr auto usePrecompiledAndBakedKernel = true;

constexpr auto useBitCode = false;
constexpr auto useBakeKernel = false;

#else

constexpr auto usePrecompiledAndBakedKernel = false;

#if defined( ORO_PRECOMPILED )
constexpr auto useBitCode = true; // this flag means we use the bitcode file
#else
constexpr auto useBitCode = false;
#endif

#if defined( ORO_PP_LOAD_FROM_STRING )
constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embedded in the binary ( as a string )
#else
constexpr auto useBakeKernel = false;
#endif

#endif
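
For reference, a sketch of the four define combinations and the mode each one selects (summarizing the flags above):

// ORO_PRECOMPILED | ORO_PP_LOAD_FROM_STRING | resulting mode
//       yes       |          yes            | usePrecompiledAndBakedKernel ( kernels baked into the binary )
//       yes       |          no             | useBitCode ( load the precompiled bitcode / fatbin file )
//       no        |          yes            | useBakeKernel ( compile the HIP source embedded as a string )
//       no        |          no             | compile the kernels from source files at runtime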

@@ -124,23 +127,6 @@ RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils, oroStream stream,
configure( kernelPath, includeDir, stream );
}

void RadixSort::exclusiveScanCpu( const Oro::GpuMemory<int>& countsGpu, Oro::GpuMemory<int>& offsetsGpu ) const noexcept
{
const auto buffer_size = countsGpu.size();

std::vector<int> counts = countsGpu.getData();
std::vector<int> offsets( buffer_size );

int sum = 0;
for( int i = 0; i < counts.size(); ++i )
{
offsets[i] = sum;
sum += counts[i];
}

offsetsGpu.copyFromHost( offsets.data(), std::size( offsets ) );
}
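
For reference, the exclusive scan this removed helper computed, shown on illustrative data:

// counts  = { 3, 1, 2, 0 }
// offsets = { 0, 3, 4, 6 }   // offsets[i] = sum of counts[0 .. i-1], offsets[0] = 0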

void RadixSort::compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept
{
static constexpr auto defaultKernelPath{ "../ParallelPrimitives/RadixSortKernels.h" };
@@ -172,77 +158,38 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
binaryPath = getCurrentDir();
binaryPath += isAmd ? "oro_compiled_kernels.hipfb" : "oro_compiled_kernels.fatbin";
log = "loading pre-compiled kernels at path : " + binaryPath;

m_num_threads_per_block_for_count = DEFAULT_COUNT_BLOCK_SIZE;
m_num_threads_per_block_for_scan = DEFAULT_SCAN_BLOCK_SIZE;
m_num_threads_per_block_for_sort = DEFAULT_SORT_BLOCK_SIZE;

m_warp_size = DEFAULT_WARP_SIZE;
}
else
{
log = "compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir;

m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE;
m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE;
m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE;

m_warp_size = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;

assert( m_num_threads_per_block_for_count % m_warp_size == 0 );
assert( m_num_threads_per_block_for_scan % m_warp_size == 0 );
assert( m_num_threads_per_block_for_sort % m_warp_size == 0 );
}

m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / m_warp_size;

if( m_flags == Flag::LOG )
{
std::cout << log << std::endl;
}


struct Record
{
std::string kernelName;
Kernel kernelType;
};

const std::vector<Record> records{
{ "CountKernel", Kernel::COUNT }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL }, { "SortKernel", Kernel::SORT },
{ "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
};

const auto includeArg{ "-I" + currentIncludeDir };
const auto overwrite_flag = "-DOVERWRITE";
const auto count_block_size_param = "-DCOUNT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_count );
const auto scan_block_size_param = "-DSCAN_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_scan );
const auto sort_block_size_param = "-DSORT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_sort );
const auto sort_num_warps_param = "-DSORT_NUM_WARPS_PER_BLOCK_VAL=" + std::to_string( m_num_warps_per_block_for_sort );

std::vector<const char*> opts;

if( const std::string device_name = m_props.name; device_name.find( "NVIDIA" ) != std::string::npos )
{
opts.push_back( "--use_fast_math" );
}
else
{
opts.push_back( "-ffast-math" );
}

opts.push_back( includeArg.c_str() );
opts.push_back( overwrite_flag );
opts.push_back( count_block_size_param.c_str() );
opts.push_back( scan_block_size_param.c_str() );
opts.push_back( sort_block_size_param.c_str() );
opts.push_back( sort_num_warps_param.c_str() );

const std::vector<Record> records{ { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS },
{ "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
{ "GHistogram", Kernel::SORT_GHISTOGRAM },
{ "OnesweepReorderKey64", Kernel::SORT_ONESWEEP_REORDER_KEY_64 },
{ "OnesweepReorderKeyPair64", Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } };

for( const auto& record : records )
{
if constexpr( usePrecompiledAndBakedKernel )
{
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() );
}
else if constexpr( useBakeKernel )
{
@@ -262,120 +209,101 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
}
}

return;
}

int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept
{
const int warpPerWG = blockSize / m_warp_size;
const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / m_warp_size;
const int occupancyFromWarp = ( warpPerWGP > 0 ) ? ( warpPerWGP / warpPerWG ) : 1;

const int occupancy = std::max( 1, occupancyFromWarp );

if( m_flags == Flag::LOG )
{
std::cout << "Occupancy: " << occupancy << '\n';
}

static constexpr auto min_num_blocks = 16;
auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks;

if( m_num_threads_per_block_for_scan > BIN_SIZE )
{
// Note: both are divisible by 2
const auto base = m_num_threads_per_block_for_scan / BIN_SIZE;

// Floor
number_of_blocks = ( number_of_blocks / base ) * base;
}

return number_of_blocks;
}
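
To make the occupancy computation concrete, a small worked example with assumed device properties (numbers are illustrative, not from this PR):

// assume: m_warp_size = 32, maxThreadsPerMultiProcessor = 2048, multiProcessorCount = 30,
//         blockSize = 64, BIN_SIZE = 256, m_num_threads_per_block_for_scan = 1024
// warpPerWG        = 64 / 32         = 2
// warpPerWGP       = 2048 / 32       = 64
// occupancy        = 64 / 2          = 32
// number_of_blocks = 30 * 32         = 960
// base             = 1024 / 256      = 4
// floored result   = ( 960 / 4 ) * 4 = 960   // already a multiple of base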

void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
{
compileKernels( kernelPath, includeDir );

m_num_blocks_for_count = calculateWGsToExecute( m_num_threads_per_block_for_count );
constexpr bool enable_copying = false;
constexpr auto key_type_size = sizeof( std::remove_pointer_t<decltype( KeyValueSoA::key )> );

/// The tmp buffer size of the count kernel and the scan kernel.
constexpr u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * key_type_size;
m_gpSumBuffer.resizeAsync( gpSumBuffer, enable_copying /*copy*/, stream );

const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count;
u64 lookBackBuffer = sizeof( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE );
m_lookbackBuffer.resizeAsync( lookBackBuffer, enable_copying /*copy*/, stream );

/// @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan
/// This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly

m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan;

m_tmp_buffer.resizeAsync( tmp_buffer_size, false, stream );

if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
{
// These are for the scan kernel
m_partial_sum.resizeAsync( m_num_blocks_for_scan, false, stream );
m_is_ready.resizeAsync( m_num_blocks_for_scan, false, stream );
m_is_ready.resetAsync( stream );
}
m_tailIterator.resizeAsync( 1, enable_copying /*copy*/, stream );
m_tailIterator.resetAsync( stream );
m_gpSumCounter.resizeAsync( 1, enable_copying /*copy*/, stream );
}
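
A sketch of what configure() sets up, assuming BIN_SIZE = 256 and u32 keys (the roles are inferred from the allocation sizes above):

// m_gpSumBuffer    : sizeof( u32 ) * 256 * sizeof( u32 ) = 4 KiB -> one 256-bin global histogram per key byte
// m_lookbackBuffer : sizeof( u64 ) * 256 * LOOKBACK_TABLE_SIZE   -> one status word per ( digit, table slot )
// m_tmp_buffer     : 256 * m_num_blocks_for_count ints           -> per-block digit counts consumed by the scan
// m_tailIterator, m_gpSumCounter : single counters, zero-cleared before use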
void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; }

void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept
{
	bool keyPair = src.value != nullptr;

	// todo. it would be better to compute SINGLE_SORT_N_ITEMS_PER_WI, which we use in the kernel, dynamically rather than hard-coding it, so that the work is distributed evenly.
	// right now, setting this as large as possible is faster than multi-pass sorting.
	if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
	{
		if( keyPair )
		{
			const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
			const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit };
			OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
		}
		else
		{
			const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
			const void* args[] = { &src.key, &dst.key, &n, &startBit, &endBit };
			OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
		}
		return;
	}

	constexpr uint64_t bit_per_iteration = 8ULL;

	int nIteration = div_round_up64( endBit - startBit, bit_per_iteration );
	uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE );

	m_lookbackBuffer.resetAsync( stream );
	m_gpSumCounter.resetAsync( stream );
	m_gpSumBuffer.resetAsync( stream );

	// build the global histogram ( gHistogram ) over all key digits in a single pass.
	{
		int maxBlocksPerMP = 0;
		oroError e = oroModuleOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 );
		const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048;

		const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) };
		OrochiUtils::launch1D( oroFunctions[Kernel::SORT_GHISTOGRAM], nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream );
	}

	auto s = src;
	auto d = dst;
	for( int i = 0; i < nIteration; ++i )
	{
		if( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 )
		{
			m_lookbackBuffer.resetAsync( stream );
		} // otherwise, we can skip zero-clearing the lookback buffer

		if( keyPair )
		{
			const void* args[] = { &s.key, &d.key, &s.value, &d.value, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
			OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
		}
		else
		{
			const void* args[] = { &s.key, &d.key, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
			OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
		}
		std::swap( s, d );
	}

	// after an even number of passes the sorted data ends up back in the source buffers; copy it to dst.
	if( s.key == src.key )
	{
		m_oroutils.copyDtoDAsync( dst.key, src.key, n, stream );

		if( keyPair )
		{
			m_oroutils.copyDtoDAsync( dst.value, src.value, n, stream );
		}
	}
}

void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept { sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); }
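
A minimal caller-side sketch, assuming kernels compile from the default paths and device buffers were allocated elsewhere; the header location, constructor arguments, and waitForCompletion helper are assumptions, not taken from this diff:

#include <ParallelPrimitives/RadixSort.h> // assumed header location

void sortExample( oroDevice device, OrochiUtils& utils, oroStream stream, u32* d_keysIn, u32* d_keysOut, uint32_t n )
{
	// compile or load the kernels once; paths are illustrative
	Oro::RadixSort sorter( device, utils, stream, "../ParallelPrimitives/RadixSortKernels.h", "../" );

	// key-only sort over the full 32-bit range; the sorted keys land in d_keysOut
	sorter.sort( d_keysIn, d_keysOut, n, 0, 32, stream );

	OrochiUtils::waitForCompletion( stream ); // assumed sync helper
}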
}; // namespace Oro