From f515ce13c9513788444c27a94b1e0b093d1bd00b Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 09:16:35 +0900 Subject: [PATCH 01/11] Porting code from the old repo --- ParallelPrimitives/RadixSort.cpp | 313 ++++------ ParallelPrimitives/RadixSort.h | 87 +-- ParallelPrimitives/RadixSort.inl | 163 ------ ParallelPrimitives/RadixSortConfigs.h | 64 +- ParallelPrimitives/RadixSortKernels.h | 801 +++++++++++--------------- Test/RadixSort/main.cpp | 17 +- 6 files changed, 477 insertions(+), 968 deletions(-) delete mode 100644 ParallelPrimitives/RadixSort.inl diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index b12d7689..62c74159 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -1,25 +1,3 @@ -// -// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
-// - #include #include #include @@ -27,63 +5,40 @@ #include #include -// if ORO_PP_LOAD_FROM_STRING && ORO_PRECOMPILED -> we load the precompiled/baked kernels. -// if ORO_PP_LOAD_FROM_STRING && NOT ORO_PRECOMPILED -> we load the baked source code kernels (from Kernels.h / KernelArgs.h) -#if !defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) +#if defined( ORO_PP_LOAD_FROM_STRING ) + // Note: the include order must be in this particular form. // clang-format off #include #include // clang-format on -#else -// if Kernels.h / KernelArgs.h are not included, declare nullptr strings -static const char* hip_RadixSortKernels = nullptr; -namespace hip -{ -static const char** RadixSortKernelsArgs = nullptr; -static const char** RadixSortKernelsIncludes = nullptr; -} #endif #if defined( __GNUC__ ) #include #endif -#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) -#include // generate this header with 'convert_binary_to_array.py' -#else -const unsigned char oro_compiled_kernels_h[] = ""; -const size_t oro_compiled_kernels_h_size = 0; -#endif +constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; } +constexpr uint64_t next_multiple64( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64( val, divisor ) * divisor; } namespace { - -// if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode. 
-#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) - - // this flag means that we bake the precompiled kernels - constexpr auto usePrecompiledAndBakedKernel = true; - - constexpr auto useBitCode = false; - constexpr auto useBakeKernel = false; - +#if defined( ORO_PRECOMPILED ) +constexpr auto useBitCode = true; #else +constexpr auto useBitCode = false; +#endif - constexpr auto usePrecompiledAndBakedKernel = false; - - #if defined( ORO_PRECOMPILED ) - constexpr auto useBitCode = true; // this flag means we use the bitcode file - #else - constexpr auto useBitCode = false; - #endif - - #if defined( ORO_PP_LOAD_FROM_STRING ) - constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embeded in the binary ( as a string ) - #else - constexpr auto useBakeKernel = false; - #endif - +#if defined( ORO_PP_LOAD_FROM_STRING ) +constexpr auto useBakeKernel = true; +#else +constexpr auto useBakeKernel = false; +static const char* hip_RadixSortKernels = nullptr; +namespace hip +{ +static const char** RadixSortKernelsArgs = nullptr; +static const char** RadixSortKernelsIncludes = nullptr; +} // namespace hip #endif static_assert( !( useBitCode && useBakeKernel ), "useBitCode and useBakeKernel cannot coexist" ); @@ -124,23 +79,6 @@ RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils, oroStream stream, configure( kernelPath, includeDir, stream ); } -void RadixSort::exclusiveScanCpu( const Oro::GpuMemory& countsGpu, Oro::GpuMemory& offsetsGpu ) const noexcept -{ - const auto buffer_size = countsGpu.size(); - - std::vector counts = countsGpu.getData(); - std::vector offsets( buffer_size ); - - int sum = 0; - for( int i = 0; i < counts.size(); ++i ) - { - offsets[i] = sum; - sum += counts[i]; - } - - offsetsGpu.copyFromHost( offsets.data(), std::size( offsets ) ); -} - void RadixSort::compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept { static constexpr auto defaultKernelPath{ 
"../ParallelPrimitives/RadixSortKernels.h" }; @@ -172,35 +110,21 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string binaryPath = getCurrentDir(); binaryPath += isAmd ? "oro_compiled_kernels.hipfb" : "oro_compiled_kernels.fatbin"; log = "loading pre-compiled kernels at path : " + binaryPath; - - m_num_threads_per_block_for_count = DEFAULT_COUNT_BLOCK_SIZE; - m_num_threads_per_block_for_scan = DEFAULT_SCAN_BLOCK_SIZE; - m_num_threads_per_block_for_sort = DEFAULT_SORT_BLOCK_SIZE; - - m_warp_size = DEFAULT_WARP_SIZE; } else { log = "compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir; - - m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE; - m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE; - m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE; - - m_warp_size = ( m_props.warpSize != 0 ) ? 
m_props.warpSize : DEFAULT_WARP_SIZE; - - assert( m_num_threads_per_block_for_count % m_warp_size == 0 ); - assert( m_num_threads_per_block_for_scan % m_warp_size == 0 ); - assert( m_num_threads_per_block_for_sort % m_warp_size == 0 ); } - m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / m_warp_size; - if( m_flags == Flag::LOG ) { std::cout << log << std::endl; } + const auto includeArg{ "-I" + currentIncludeDir }; + std::vector opts; + opts.push_back( includeArg.c_str() ); + struct Record { std::string kernelName; @@ -208,174 +132,135 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string }; const std::vector records{ - { "CountKernel", Kernel::COUNT }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL }, { "SortKernel", Kernel::SORT }, - { "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV }, + { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV }, }; - const auto includeArg{ "-I" + currentIncludeDir }; - const auto overwrite_flag = "-DOVERWRITE"; - const auto count_block_size_param = "-DCOUNT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_count ); - const auto scan_block_size_param = "-DSCAN_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_scan ); - const auto sort_block_size_param = "-DSORT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_sort ); - const auto sort_num_warps_param = "-DSORT_NUM_WARPS_PER_BLOCK_VAL=" + std::to_string( m_num_warps_per_block_for_sort ); - - std::vector opts; - - if( const std::string device_name = m_props.name; device_name.find( "NVIDIA" ) != std::string::npos ) - { - opts.push_back( "--use_fast_math" ); - } - else - { - opts.push_back( "-ffast-math" ); - } - - opts.push_back( includeArg.c_str() ); - opts.push_back( overwrite_flag ); - 
opts.push_back( count_block_size_param.c_str() ); - opts.push_back( scan_block_size_param.c_str() ); - opts.push_back( sort_block_size_param.c_str() ); - opts.push_back( sort_num_warps_param.c_str() ); - for( const auto& record : records ) { - if constexpr( usePrecompiledAndBakedKernel ) - { - oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() ); - } - else if constexpr( useBakeKernel ) - { - oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes ); - } - else if constexpr( useBitCode ) +#if defined( ORO_PP_LOAD_FROM_STRING ) + oroFunctions[record.kernelType] = oroutils.getFunctionFromString( device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes ); +#else + + if constexpr( useBitCode ) { oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() ); } else { + const auto includeArg{ "-I" + currentIncludeDir }; + std::vector opts; + opts.push_back( includeArg.c_str() ); oroFunctions[record.kernelType] = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts ); } +#endif if( m_flags == Flag::LOG ) { printKernelInfo( record.kernelName, oroFunctions[record.kernelType] ); } } - return; -} - -int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept -{ - const int warpPerWG = blockSize / m_warp_size; - const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / m_warp_size; - const int occupancyFromWarp = ( warpPerWGP > 0 ) ? 
( warpPerWGP / warpPerWG ) : 1; - - const int occupancy = std::max( 1, occupancyFromWarp ); - - if( m_flags == Flag::LOG ) - { - std::cout << "Occupancy: " << occupancy << '\n'; - } - - static constexpr auto min_num_blocks = 16; - auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks; - - if( m_num_threads_per_block_for_scan > BIN_SIZE ) - { - // Note: both are divisible by 2 - const auto base = m_num_threads_per_block_for_scan / BIN_SIZE; + // TODO: bit code support? +#define LOAD_FUNC( var, kernel ) var = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), kernel, &opts ); + LOAD_FUNC( m_gHistogram, "gHistogram" ); + LOAD_FUNC( m_onesweep_reorderKey64, "onesweep_reorderKey64" ); + LOAD_FUNC( m_onesweep_reorderKeyPair64, "onesweep_reorderKeyPair64" ); +#undef LOAD_FUNC - // Floor - number_of_blocks = ( number_of_blocks / base ) * base; - } - - return number_of_blocks; } void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept { compileKernels( kernelPath, includeDir ); - m_num_blocks_for_count = calculateWGsToExecute( m_num_threads_per_block_for_count ); - - /// The tmp buffer size of the count kernel and the scan kernel. 
- - const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count; + u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * sizeof( u32 /* key type */ ); + m_gpSumBuffer.resizeAsync( gpSumBuffer, false /*copy*/, stream ); - /// @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan - /// This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly + u64 lookBackBuffer = sizeof( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE ); + m_lookbackBuffer.resizeAsync( lookBackBuffer, false /*copy*/, stream ); - m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan; - - m_tmp_buffer.resizeAsync( tmp_buffer_size, false, stream ); - - if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL ) - { - // These are for the scan kernel - m_partial_sum.resizeAsync( m_num_blocks_for_scan, false, stream ); - m_is_ready.resizeAsync( m_num_blocks_for_scan, false, stream ); - m_is_ready.resetAsync( stream ); - } + m_tailIterator.resizeAsync( 1, false /*copy*/, stream ); + m_tailIterator.resetAsync( stream ); + m_gpSumCounter.resizeAsync( 1, false /*copy*/, stream ); } void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; } -void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream ) noexcept +void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept { + bool keyPair = src.value != nullptr; + // todo. 
better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly // right now, setting this as large as possible is faster than multi pass sorting if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI ) { - const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV]; - const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit }; - OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream ); + if( keyPair ) + { + const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV]; + const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit }; + OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream ); + } + else + { + const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS]; + const void* args[] = { &src, &dst, &n, &startBit, &endBit }; + OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream ); + } return; } - auto* s{ &src }; - auto* d{ &dst }; + int nIteration = div_round_up64( endBit - startBit, 8 ); + uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE ); - for( int i = startBit; i < endBit; i += N_RADIX ) - { - sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream ); - - std::swap( s, d ); - } + m_lookbackBuffer.resetAsync( stream ); + m_gpSumCounter.resetAsync( stream ); + m_gpSumBuffer.resetAsync( stream ); - if( s == &src ) + // counter for gHistogram. { - OrochiUtils::copyDtoDAsync( dst.key, src.key, n, stream ); - OrochiUtils::copyDtoDAsync( dst.value, src.value, n, stream ); - } -} + int maxBlocksPerMP = 0; + oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, m_gHistogram, GHISTOGRAM_THREADS_PER_BLOCK, 0 ); + const int nBlocks = e == oroSuccess ? 
maxBlocksPerMP * m_props.multiProcessorCount : 2048; -void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int endBit, oroStream stream ) noexcept -{ - // todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly - // right now, setting this as large as possible is faster than multi pass sorting - if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI ) - { - const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS]; - const void* args[] = { &src, &dst, &n, &startBit, &endBit }; - OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream ); - return; + const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) }; + OrochiUtils::launch1D( m_gHistogram, nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream ); } - auto* s{ &src }; - auto* d{ &dst }; - - for( int i = startBit; i < endBit; i += N_RADIX ) + auto s = src; + auto d = dst; + for( int i = 0; i < nIteration; i++ ) { - sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream ); + if( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 ) + { + m_lookbackBuffer.resetAsync( stream ); + } // other wise, we can skip zero clear look back buffer + if( keyPair ) + { + const void* args[] = { &s.key, &d.key, &s.value, &d.value, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i }; + OrochiUtils::launch1D( m_onesweep_reorderKeyPair64, numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); + } + else + { + const void* args[] = { &s.key, &d.key, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i }; + OrochiUtils::launch1D( m_onesweep_reorderKey64, numberOfBlocks * 
REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); + } std::swap( s, d ); } - if( s == &src ) + if( s.key == src.key ) { - OrochiUtils::copyDtoDAsync( dst, src, n, stream ); + oroMemcpyDtoDAsync( (oroDeviceptr)dst.key, (oroDeviceptr)src.key, sizeof( uint32_t ) * n, stream ); + + if( keyPair ) + { + oroMemcpyDtoDAsync( (oroDeviceptr)dst.value, (oroDeviceptr)src.value, sizeof( uint32_t ) * n, stream ); + } } } +void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept +{ + sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); +} }; // namespace Oro diff --git a/ParallelPrimitives/RadixSort.h b/ParallelPrimitives/RadixSort.h index e22ee35f..a20c9f0c 100644 --- a/ParallelPrimitives/RadixSort.h +++ b/ParallelPrimitives/RadixSort.h @@ -1,26 +1,3 @@ -// -// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. -// - - #pragma once #include @@ -66,84 +43,44 @@ class RadixSort final void setFlag( Flag flag ) noexcept; - void sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream = 0 ) noexcept; + void sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream = 0 ) noexcept; - void sort( const u32* src, const u32* dst, int n, int startBit, int endBit, oroStream stream = 0 ) noexcept; + void sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream = 0 ) noexcept; private: - template - void sort1pass( const T src, const T dst, int n, int startBit, int endBit, oroStream stream ) noexcept; - - /// @brief Compile the kernels for radix sort. - /// @param kernelPath The kernel path. - /// @param includeDir The include directory. + // @brief Compile the kernels for radix sort. + // @param kernelPath The kernel path. + // @param includeDir The include directory. void compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept; - [[nodiscard]] int calculateWGsToExecute( const int blockSize ) const noexcept; - - /// @brief Exclusive scan algorithm on CPU for testing. - /// It copies the count result from the Device to Host before computation, and then copies the offsets back from Host to Device afterward. - /// @param countsGpu The count result in GPU memory. Otuput: The offset. - /// @param offsetsGpu The offsets. - void exclusiveScanCpu( const Oro::GpuMemory& countsGpu, Oro::GpuMemory& offsetsGpu ) const noexcept; - /// @brief Configure the settings, compile the kernels and allocate the memory. /// @param kernelPath The kernel path. 
/// @param includeDir The include directory. void configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept; private: - // GPU blocks for the count kernel - int m_num_blocks_for_count{}; - - // GPU blocks for the scan kernel - int m_num_blocks_for_scan{}; - Flag m_flags{ Flag::NO_LOG }; enum class Kernel { - COUNT, - SCAN_SINGLE_WG, - SCAN_PARALLEL, - SORT, - SORT_KV, SORT_SINGLE_PASS, SORT_SINGLE_PASS_KV, }; std::unordered_map oroFunctions; - /// @brief The enum class which indicates the selected algorithm of prefix scan. - enum class ScanAlgo - { - SCAN_CPU, - SCAN_GPU_SINGLE_WG, - SCAN_GPU_PARALLEL, - }; - - constexpr static auto selectedScanAlgo{ ScanAlgo::SCAN_GPU_PARALLEL }; - - GpuMemory m_partial_sum; - GpuMemory m_is_ready; - oroDevice m_device{}; oroDeviceProp m_props{}; OrochiUtils& m_oroutils; - // This buffer holds the "bucket" table from all GPU blocks. - GpuMemory m_tmp_buffer; + oroFunction m_gHistogram; + oroFunction m_onesweep_reorderKey64; + oroFunction m_onesweep_reorderKeyPair64; - int m_num_threads_per_block_for_count{}; - int m_num_threads_per_block_for_scan{}; - int m_num_threads_per_block_for_sort{}; - - int m_num_warps_per_block_for_sort{}; - - int m_warp_size{}; + GpuMemory m_lookbackBuffer; + GpuMemory m_gpSumBuffer; + GpuMemory m_gpSumCounter; + GpuMemory m_tailIterator; }; - -#include - }; // namespace Oro diff --git a/ParallelPrimitives/RadixSort.inl b/ParallelPrimitives/RadixSort.inl deleted file mode 100644 index ad9ecdb5..00000000 --- a/ParallelPrimitives/RadixSort.inl +++ /dev/null @@ -1,163 +0,0 @@ - - -namespace -{ - -struct Empty -{ -}; - -/// @brief Call the callable and measure the elapsed time using the Stopwatch. -/// @tparam CallableType The type of the callable to be invoked in this function. -/// @tparam RecordType The type of the object that stores the recorded times. -/// @tparam enable_profile The elapsed time will be recorded if this is set to True. 
-/// @param callable The callable object to be called. -/// @param time_record The object that stores the recorded times. -/// @param index The index indicates where to store the elapsed time in @c time_record -/// @param stream The GPU stream -template -constexpr void execute( CallableType&& callable, RecordType& time_record, const int index, const oroStream stream ) noexcept -{ - using TimerType = std::conditional_t; - - TimerType stopwatch; - - if constexpr( enable_profile ) - { - stopwatch.start(); - } - - std::invoke( std::forward( callable ) ); - - if constexpr( enable_profile ) - { - OrochiUtils::waitForCompletion( stream ); - stopwatch.stop(); - time_record[index] = stopwatch.getMs(); - } -} - -template -void resize_record( T& t ) noexcept -{ - if constexpr( enable_profile ) - { - t.resize( 3 ); - } -} - -template -void print_record( const T& t ) noexcept -{ - if constexpr( enable_profile ) - { - printf( "%3.2f, %3.2f, %3.2f\n", t[0], t[1], t[2] ); - } -} - -} // namespace - -template -void RadixSort::sort1pass( const T src, const T dst, int n, int startBit, int endBit, oroStream stream ) noexcept -{ - static constexpr auto enable_profile = false; - - const u32* srcKey{ nullptr }; - const u32* dstKey{ nullptr }; - - const u32* srcVal{ nullptr }; - const u32* dstVal{ nullptr }; - - static constexpr auto enable_key_value_pair_sorting{ std::is_same_v }; - - if constexpr( enable_key_value_pair_sorting ) - { - srcKey = src.key; - dstKey = dst.key; - - srcVal = src.value; - dstVal = dst.value; - } - else - { - static_assert( std::is_same_v || std::is_same_v ); - srcKey = src; - dstKey = dst; - } - - const int nItemPerWG = ( n + m_num_blocks_for_count - 1 ) / m_num_blocks_for_count; - - // Timer records - - using RecordType = std::conditional_t, Empty>; - RecordType t; - - resize_record( t ); - - const auto launch_count_kernel = [&]() noexcept - { - const auto num_total_thread_for_count = m_num_threads_per_block_for_count * m_num_blocks_for_count; - - const auto 
func{ oroFunctions[Kernel::COUNT] }; - const void* args[] = { &srcKey, arg_cast( m_tmp_buffer.address() ), &n, &nItemPerWG, &startBit, &m_num_blocks_for_count }; - OrochiUtils::launch1D( func, num_total_thread_for_count, args, m_num_threads_per_block_for_count, 0, stream ); - }; - - execute( launch_count_kernel, t, 0, stream ); - - const auto launch_scan_kernel = [&]() noexcept - { - switch( selectedScanAlgo ) - { - case ScanAlgo::SCAN_CPU: - { - exclusiveScanCpu( m_tmp_buffer, m_tmp_buffer ); - } - break; - - case ScanAlgo::SCAN_GPU_SINGLE_WG: - { - const void* args[] = { arg_cast( m_tmp_buffer.address() ), arg_cast( m_tmp_buffer.address() ), &m_num_blocks_for_count }; - OrochiUtils::launch1D( oroFunctions[Kernel::SCAN_SINGLE_WG], WG_SIZE * m_num_blocks_for_count, args, WG_SIZE, 0, stream ); - } - break; - - case ScanAlgo::SCAN_GPU_PARALLEL: - { - const auto num_total_thread_for_scan = m_num_threads_per_block_for_scan * m_num_blocks_for_scan; - - const void* args[] = { arg_cast( m_tmp_buffer.address() ), arg_cast( m_tmp_buffer.address() ), arg_cast( m_partial_sum.address() ), arg_cast( m_is_ready.address() ) }; - OrochiUtils::launch1D( oroFunctions[Kernel::SCAN_PARALLEL], num_total_thread_for_scan, args, m_num_threads_per_block_for_scan, 0, stream ); - } - break; - - default: - exclusiveScanCpu( m_tmp_buffer, m_tmp_buffer ); - break; - } - }; - - execute( launch_scan_kernel, t, 1, stream ); - - const auto launch_sort_kernel = [&]() noexcept - { - const auto num_blocks_for_sort = m_num_blocks_for_count; - const auto num_total_thread_for_sort = m_num_threads_per_block_for_sort * num_blocks_for_sort; - const auto num_items_per_block = nItemPerWG; - - if constexpr( enable_key_value_pair_sorting ) - { - const void* args[] = { &srcKey, &srcVal, &dstKey, &dstVal, arg_cast( m_tmp_buffer.address() ), &n, &num_items_per_block, &startBit, &num_blocks_for_sort }; - OrochiUtils::launch1D( oroFunctions[Kernel::SORT_KV], num_total_thread_for_sort, args, 
m_num_threads_per_block_for_sort, 0, stream ); - } - else - { - const void* args[] = { &srcKey, &dstKey, arg_cast( m_tmp_buffer.address() ), &n, &num_items_per_block, &startBit, &num_blocks_for_sort }; - OrochiUtils::launch1D( oroFunctions[Kernel::SORT], num_total_thread_for_sort, args, m_num_threads_per_block_for_sort, 0, stream ); - } - }; - - execute( launch_sort_kernel, t, 2, stream ); - - print_record( t ); -} diff --git a/ParallelPrimitives/RadixSortConfigs.h b/ParallelPrimitives/RadixSortConfigs.h index 6857254a..40cd1120 100644 --- a/ParallelPrimitives/RadixSortConfigs.h +++ b/ParallelPrimitives/RadixSortConfigs.h @@ -1,25 +1,3 @@ -// -// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
-// - #pragma once namespace Oro @@ -29,29 +7,11 @@ constexpr auto N_RADIX{ 8 }; constexpr auto BIN_SIZE{ 1 << N_RADIX }; constexpr auto RADIX_MASK{ ( 1 << N_RADIX ) - 1 }; constexpr auto PACK_FACTOR{ sizeof( int ) / sizeof( char ) }; -constexpr auto N_PACKED{ BIN_SIZE / PACK_FACTOR }; -constexpr auto PACK_MAX{ 255 }; -constexpr auto N_PACKED_PER_WI{ N_PACKED / WG_SIZE }; -constexpr auto N_BINS_PER_WI{ BIN_SIZE / WG_SIZE }; constexpr auto N_BINS_4BIT{ 16 }; constexpr auto N_BINS_PACK_FACTOR{ sizeof( long long ) / sizeof( short ) }; constexpr auto N_BINS_PACKED_4BIT{ N_BINS_4BIT / N_BINS_PACK_FACTOR }; -constexpr auto N_BINS_8BIT{ 1 << 8 }; - -constexpr auto DEFAULT_WARP_SIZE{ 32 }; - -constexpr auto DEFAULT_NUM_WARPS_PER_BLOCK{ 8 }; - -// count config - -constexpr auto DEFAULT_COUNT_BLOCK_SIZE{ DEFAULT_WARP_SIZE * DEFAULT_NUM_WARPS_PER_BLOCK }; - -// scan configs -constexpr auto DEFAULT_SCAN_BLOCK_SIZE{ DEFAULT_WARP_SIZE * DEFAULT_NUM_WARPS_PER_BLOCK }; - // sort configs -constexpr auto DEFAULT_SORT_BLOCK_SIZE{ DEFAULT_WARP_SIZE * DEFAULT_NUM_WARPS_PER_BLOCK }; constexpr auto SORT_N_ITEMS_PER_WI{ 12 }; constexpr auto SINGLE_SORT_N_ITEMS_PER_WI{ 24 }; constexpr auto SINGLE_SORT_WG_SIZE{ 128 }; @@ -60,8 +20,26 @@ constexpr auto SINGLE_SORT_WG_SIZE{ 128 }; static_assert( BIN_SIZE % 2 == 0 ); -// Notice that, on some GPUs, the max size of a GPU block cannot be greater than 256 -static_assert( DEFAULT_COUNT_BLOCK_SIZE % DEFAULT_WARP_SIZE == 0 ); -static_assert( DEFAULT_SCAN_BLOCK_SIZE % DEFAULT_WARP_SIZE == 0 ); +constexpr int WARP_SIZE = 32; + +constexpr int RADIX_SORT_BLOCK_SIZE = 4096; + +constexpr int GHISTOGRAM_ITEM_PER_BLOCK = 2048; +constexpr int GHISTOGRAM_THREADS_PER_BLOCK = 256; +constexpr int GHISTOGRAM_ITEMS_PER_THREAD = GHISTOGRAM_ITEM_PER_BLOCK / GHISTOGRAM_THREADS_PER_BLOCK; + +constexpr int REORDER_NUMBER_OF_WARPS = 8; +constexpr int REORDER_NUMBER_OF_THREADS_PER_BLOCK = WARP_SIZE * REORDER_NUMBER_OF_WARPS; +constexpr int REORDER_NUMBER_OF_ITEM_PER_WARP 
= RADIX_SORT_BLOCK_SIZE / REORDER_NUMBER_OF_WARPS; +constexpr int REORDER_NUMBER_OF_ITEM_PER_THREAD = REORDER_NUMBER_OF_ITEM_PER_WARP / 32; + +constexpr int LOOKBACK_TABLE_SIZE = 1024; +constexpr int MAX_LOOK_BACK = 64; +constexpr int TAIL_BITS = 5; +constexpr auto TAIL_MASK = 0xFFFFFFFFu << TAIL_BITS; +static_assert( MAX_LOOK_BACK < LOOKBACK_TABLE_SIZE, "" ); + +//static_assert( BIN_SIZE <= REORDER_NUMBER_OF_THREADS_PER_BLOCK, "please check scanExclusive" ); +//static_assert( BIN_SIZE % REORDER_NUMBER_OF_THREADS_PER_BLOCK == 0, "please check prefixSumExclusive on onesweep_reorder" ); }; // namespace Oro \ No newline at end of file diff --git a/ParallelPrimitives/RadixSortKernels.h b/ParallelPrimitives/RadixSortKernels.h index 100553ba..75e94517 100644 --- a/ParallelPrimitives/RadixSortKernels.h +++ b/ParallelPrimitives/RadixSortKernels.h @@ -1,28 +1,10 @@ -// -// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. -// - #include #define LDS_BARRIER __syncthreads() +#if defined( CUDART_VERSION ) && CUDART_VERSION >= 9000 +#define ITS 1 +#endif + namespace { @@ -34,144 +16,6 @@ using u32 = unsigned int; using u64 = unsigned long long; } // namespace -// #define NV_WORKAROUND 1 - -// default values -#if defined( OVERWRITE ) - -constexpr auto COUNT_WG_SIZE{ COUNT_WG_SIZE_VAL }; -constexpr auto SCAN_WG_SIZE{ SCAN_WG_SIZE_VAL }; -constexpr auto SORT_WG_SIZE{ SORT_WG_SIZE_VAL }; -constexpr auto SORT_NUM_WARPS_PER_BLOCK{ SORT_NUM_WARPS_PER_BLOCK_VAL }; - -#else - -constexpr auto COUNT_WG_SIZE{ DEFAULT_COUNT_BLOCK_SIZE }; -constexpr auto SCAN_WG_SIZE{ DEFAULT_SCAN_BLOCK_SIZE }; -constexpr auto SORT_WG_SIZE{ DEFAULT_SORT_BLOCK_SIZE }; -constexpr auto SORT_NUM_WARPS_PER_BLOCK{ DEFAULT_NUM_WARPS_PER_BLOCK }; - -#endif - -#if defined( __CUDACC__ ) -constexpr int SORT_SUBBLOCK_SIZE = 2048; -#else -constexpr int SORT_SUBBLOCK_SIZE = 4096; -#endif - - -__device__ constexpr u32 getMaskedBits( const u32 value, const u32 shift ) noexcept { return ( value >> shift ) & RADIX_MASK; } - -extern "C" __global__ void CountKernel( int* gSrc, int* gDst, int gN, int gNItemsPerWG, const int START_BIT, const int N_WGS_EXECUTED ) -{ - __shared__ int table[BIN_SIZE]; - - for( int i = threadIdx.x; i < BIN_SIZE; i += COUNT_WG_SIZE ) - { - table[i] = 0; - } - - __syncthreads(); - - const int offset = blockIdx.x * gNItemsPerWG; - const int upperBound = ( offset + gNItemsPerWG > gN ) ? 
gN - offset : gNItemsPerWG; - - for( int i = threadIdx.x; i < upperBound; i += COUNT_WG_SIZE ) - { - const int idx = offset + i; - const int tableIdx = getMaskedBits( gSrc[idx], START_BIT ); - atomicAdd( &table[tableIdx], 1 ); - } - - __syncthreads(); - - for( int i = threadIdx.x; i < BIN_SIZE; i += COUNT_WG_SIZE ) - { - gDst[i * N_WGS_EXECUTED + blockIdx.x] = table[i]; - } -} - -template -struct ScanImpl -{ - __device__ static T exec( T a ) - { - T b = __shfl( a, threadIdx.x - STRIDE ); - if( threadIdx.x >= STRIDE ) a += b; - return ScanImpl::exec( a ); - } -}; - -template -struct ScanImpl -{ - __device__ static T exec( T a ) { return a; } -}; - -template -__device__ void waveScanInclusive( T& a, int width ) -{ -#if 0 - a = ScanImpl::exec( a ); -#else - for( int i = 1; i < width; i *= 2 ) - { - T b = __shfl( a, threadIdx.x - i ); - if( threadIdx.x >= i ) a += b; - } -#endif -} - -template -__device__ T waveScanExclusive( T& a, int width ) -{ - waveScanInclusive( a, width ); - - T sum = __shfl( a, width - 1 ); - a = __shfl( a, threadIdx.x - 1 ); - if( threadIdx.x == 0 ) a = 0; - - return sum; -} - -template -__device__ void ldsScanInclusive( T* lds, int width ) -{ - // The width cannot exceed WG_SIZE - __shared__ T temp[2][WG_SIZE]; - - constexpr int MAX_INDEX = 1; - int outIndex = 0; - int inIndex = 1; - - temp[outIndex][threadIdx.x] = lds[threadIdx.x]; - __syncthreads(); - - for( int i = 1; i < width; i *= 2 ) - { - // Swap in and out index for the buffers - - outIndex = MAX_INDEX - outIndex; - inIndex = MAX_INDEX - outIndex; - - if( threadIdx.x >= i ) - { - temp[outIndex][threadIdx.x] = temp[inIndex][threadIdx.x] + temp[inIndex][threadIdx.x - i]; - } - else - { - temp[outIndex][threadIdx.x] = temp[inIndex][threadIdx.x]; - } - - __syncthreads(); - } - - lds[threadIdx.x] = temp[outIndex][threadIdx.x]; - - // Ensure the results are written in LDS and are observable in a block (workgroup) before return. 
- __threadfence_block(); -} - template __device__ T ldsScanExclusive( T* lds, int width ) { @@ -355,140 +199,6 @@ __device__ void localSort4bitMulti( int* keys, u32* ldsKeys, int* values, u32* l } } -__device__ void localSort8bitMulti_shared_bin( int* keys, u32* ldsKeys, const int START_BIT ) -{ - __shared__ unsigned table[BIN_SIZE]; - - for( int i = threadIdx.x; i < BIN_SIZE; i += SORT_WG_SIZE ) - { - table[i] = 0U; - } - - LDS_BARRIER; - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; ++i ) - { - const int tableIdx = ( keys[i] >> START_BIT ) & RADIX_MASK; - atomicAdd( &table[tableIdx], 1 ); - } - - LDS_BARRIER; - - int globalSum = 0; - for( int binId = 0; binId < BIN_SIZE; binId += SORT_WG_SIZE * 2 ) - { - unsigned* globalOffset = &table[binId]; - const unsigned currentGlobalSum = ldsScanExclusive( globalOffset, SORT_WG_SIZE * 2 ); - globalOffset[threadIdx.x * 2] += globalSum; - globalOffset[threadIdx.x * 2 + 1] += globalSum; - globalSum += currentGlobalSum; - } - - LDS_BARRIER; - - __shared__ u32 keyBuffer[SORT_WG_SIZE * SORT_N_ITEMS_PER_WI]; - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; ++i ) - { - keyBuffer[threadIdx.x * SORT_N_ITEMS_PER_WI + i] = keys[i]; - } - - LDS_BARRIER; - - if( threadIdx.x == 0 ) - { - for( int i = 0; i < SORT_WG_SIZE * SORT_N_ITEMS_PER_WI; ++i ) - { - const int tableIdx = ( keyBuffer[i] >> START_BIT ) & RADIX_MASK; - const int writeIndex = table[tableIdx]; - - ldsKeys[writeIndex] = keyBuffer[i]; - - ++table[tableIdx]; - } - } - - LDS_BARRIER; - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; ++i ) - { - keys[i] = ldsKeys[threadIdx.x * SORT_N_ITEMS_PER_WI + i]; - } -} - -__device__ void localSort8bitMulti_group( int* keys, u32* ldsKeys, const int START_BIT ) -{ - constexpr auto N_GROUP_SIZE{ N_BINS_8BIT / ( sizeof( u64 ) / sizeof( u16 ) ) }; - - __shared__ union - { - u16 m_ungrouped[SORT_WG_SIZE + 1][N_BINS_8BIT]; - u64 m_grouped[SORT_WG_SIZE + 1][N_GROUP_SIZE]; - } lds; - - for( int i = 0; i < N_GROUP_SIZE; ++i ) - { - 
lds.m_grouped[threadIdx.x][i] = 0U; - } - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; i++ ) - { - const auto in8bit = ( keys[i] >> START_BIT ) & RADIX_MASK; - ++lds.m_ungrouped[threadIdx.x][in8bit]; - } - - LDS_BARRIER; - - for( int groupId = threadIdx.x; groupId < N_GROUP_SIZE; groupId += SORT_WG_SIZE ) - { - u64 sum = 0U; - for( int i = 0; i < SORT_WG_SIZE; i++ ) - { - const auto current = lds.m_grouped[i][groupId]; - lds.m_grouped[i][groupId] = sum; - sum += current; - } - lds.m_grouped[SORT_WG_SIZE][groupId] = sum; - } - - LDS_BARRIER; - - int globalSum = 0; - for( int binId = 0; binId < N_BINS_8BIT; binId += SORT_WG_SIZE * 2 ) - { - auto* globalOffset = &lds.m_ungrouped[SORT_WG_SIZE][binId]; - const int currentGlobalSum = ldsScanExclusive( globalOffset, SORT_WG_SIZE * 2 ); - globalOffset[threadIdx.x * 2] += globalSum; - globalOffset[threadIdx.x * 2 + 1] += globalSum; - globalSum += currentGlobalSum; - } - - LDS_BARRIER; - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; i++ ) - { - const auto in8bit = ( keys[i] >> START_BIT ) & RADIX_MASK; - const auto offset = lds.m_ungrouped[SORT_WG_SIZE][in8bit]; - const auto rank = lds.m_ungrouped[threadIdx.x][in8bit]++; - - ldsKeys[offset + rank] = keys[i]; - } - - LDS_BARRIER; - - for( int i = 0; i < SORT_N_ITEMS_PER_WI; i++ ) - { - keys[i] = ldsKeys[threadIdx.x * SORT_N_ITEMS_PER_WI + i]; - } -} - -template -__device__ void localSort8bitMulti( int* keys, u32* ldsKeys, int* values, u32* ldsValues, const int START_BIT ) -{ - localSort4bitMulti( keys, ldsKeys, values, ldsValues, START_BIT ); - if( N_RADIX > 4 ) localSort4bitMulti( keys, ldsKeys, values, ldsValues, START_BIT + 4 ); -} - template __device__ void SortSinglePass( int* gSrcKey, int* gSrcVal, int* gDstKey, int* gDstVal, int gN, const int START_BIT, const int END_BIT ) { @@ -543,278 +253,433 @@ extern "C" __global__ void SortSinglePassKernel( int* gSrcKey, int* gDstKey, int extern "C" __global__ void SortSinglePassKVKernel( int* gSrcKey, int* gSrcVal, int* gDstKey, int* 
gDstVal, int gN, const int START_BIT, const int END_BIT ) { SortSinglePass( gSrcKey, gSrcVal, gDstKey, gDstVal, gN, START_BIT, END_BIT ); } -extern "C" __global__ void ParallelExclusiveScanSingleWG( int* gCount, int* gHistogram, const int N_WGS_EXECUTED ) +using RADIX_SORT_KEY_TYPE = u32; +using RADIX_SORT_VALUE_TYPE = u32; + +#if defined( DESCENDING_ORDER ) +constexpr u32 ORDER_MASK_32 = 0xFFFFFFFF; +constexpr u64 ORDER_MASK_64 = 0xFFFFFFFFFFFFFFFFllu; +#else +constexpr u32 ORDER_MASK_32 = 0; +constexpr u64 ORDER_MASK_64 = 0llu; +#endif + +__device__ constexpr u32 div_round_up( u32 val, u32 divisor ) noexcept { return ( val + divisor - 1 ) / divisor; } + +template +__device__ void clearShared( T* sMem, T value ) { - // Use a single WG. - if( blockIdx.x != 0 ) + for( int i = 0; i < NElement; i += NThread ) { - return; + if( i < NElement ) + { + sMem[i + threadIdx.x] = value; + } } +} - // LDS for the parallel scan of the global sum: - // First we store the sum of the counters of each number to it, - // then we compute the global offset using parallel exclusive scan. - __shared__ int blockBuffer[BIN_SIZE]; +__device__ inline u32 getKeyBits( u32 x ) { return x ^ ORDER_MASK_32; } +__device__ inline u64 getKeyBits( u64 x ) { return x ^ ORDER_MASK_64; } +__device__ inline u32 extractDigit( u32 x, u32 bitLocation ) { return ( x >> bitLocation ) & RADIX_MASK; } +__device__ inline u32 extractDigit( u64 x, u32 bitLocation ) { return (u32)( ( x >> bitLocation ) & RADIX_MASK ); } - // fill the LDS with the local sum +template +__device__ inline T scanExclusive( T prefix, T* sMemIO, int nElement ) +{ + // assert(nElement <= nThreads) + bool active = threadIdx.x < nElement; + T value = active ? 
sMemIO[threadIdx.x] : 0; + T x = value; - for( int binId = threadIdx.x; binId < BIN_SIZE; binId += WG_SIZE ) + for( u32 offset = 1; offset < nElement; offset <<= 1 ) { - // Do exclusive scan for each segment handled by each WI in a WG - - int localThreadSum = 0; - for( int i = 0; i < N_WGS_EXECUTED; ++i ) + if( active && offset <= threadIdx.x ) { - int current = gCount[binId * N_WGS_EXECUTED + i]; - gCount[binId * N_WGS_EXECUTED + i] = localThreadSum; - - localThreadSum += current; + x += sMemIO[threadIdx.x - offset]; } - // Store the thread local sum to LDS. + __syncthreads(); + + if( active ) + { + sMemIO[threadIdx.x] = x; + } - blockBuffer[binId] = localThreadSum; + __syncthreads(); } - LDS_BARRIER; + T sum = sMemIO[nElement - 1]; - // Do parallel exclusive scan on the LDS + __syncthreads(); - int globalSum = 0; - for( int binId = 0; binId < BIN_SIZE; binId += WG_SIZE * 2 ) + if( active ) { - int* globalOffset = &blockBuffer[binId]; - int currentGlobalSum = ldsScanExclusive( globalOffset, WG_SIZE * 2 ); - globalOffset[threadIdx.x * 2] += globalSum; - globalOffset[threadIdx.x * 2 + 1] += globalSum; - globalSum += currentGlobalSum; + sMemIO[threadIdx.x] = x + prefix - value; } - LDS_BARRIER; + __syncthreads(); - // Add the global offset to the global histogram. 
+ return sum; +} + +extern "C" __global__ void gHistogram( RADIX_SORT_KEY_TYPE* inputs, u32 numberOfInputs, u32* gpSumBuffer, u32 startBits, u32* counter ) +{ + __shared__ u32 localCounters[sizeof( RADIX_SORT_KEY_TYPE )][BIN_SIZE]; - for( int binId = threadIdx.x; binId < BIN_SIZE; binId += WG_SIZE ) + for( int i = 0; i < sizeof( RADIX_SORT_KEY_TYPE ); i++ ) { - for( int i = 0; i < N_WGS_EXECUTED; ++i ) + for( int j = threadIdx.x; j < BIN_SIZE; j += GHISTOGRAM_THREADS_PER_BLOCK ) { - gHistogram[binId * N_WGS_EXECUTED + i] += blockBuffer[binId]; + localCounters[i][j] = 0; } } -} -extern "C" __device__ void WorkgroupSync( int threadId, int blockId, int currentSegmentSum, int* currentGlobalOffset, volatile int* gPartialSum, volatile bool* gIsReady ) -{ - if( threadId == 0 ) + u32 numberOfBlocks = div_round_up( numberOfInputs, GHISTOGRAM_ITEM_PER_BLOCK ); + __shared__ u32 iBlock; + for(;;) { - int offset = 0; + if( threadIdx.x == 0 ) + { + iBlock = atomicInc( counter, 0xFFFFFFFF ); + } + + __syncthreads(); - if( blockId != 0 ) + if( numberOfBlocks <= iBlock ) + break; + + for( int j = 0; j < GHISTOGRAM_ITEMS_PER_THREAD; j++ ) { - while( !gIsReady[blockId - 1] ) + u32 itemIndex = iBlock * GHISTOGRAM_ITEM_PER_BLOCK + threadIdx.x * GHISTOGRAM_ITEMS_PER_THREAD + j; + if( itemIndex < numberOfInputs ) { + auto item = inputs[itemIndex]; + for( int i = 0; i < sizeof( RADIX_SORT_KEY_TYPE ); i++ ) + { + u32 bitLocation = startBits + i * N_RADIX; + u32 bits = extractDigit( getKeyBits( item ), bitLocation ); + atomicInc( &localCounters[i][bits], 0xFFFFFFFF ); + } } - - offset = gPartialSum[blockId - 1]; - - __threadfence(); - - // Reset the value - gIsReady[blockId - 1] = false; } - gPartialSum[blockId] = offset + currentSegmentSum; - - // Ensure that the gIsReady is only modified after the gPartialSum is written. 
- __threadfence(); - - gIsReady[blockId] = true; + __syncthreads(); + } - *currentGlobalOffset = offset; + for( int i = 0; i < sizeof( RADIX_SORT_KEY_TYPE ); i++ ) + { + scanExclusive( 0, &localCounters[i][0], BIN_SIZE ); } - __syncthreads(); + for( int i = 0; i < sizeof( RADIX_SORT_KEY_TYPE ); i++ ) + { + for( int j = threadIdx.x; j < BIN_SIZE; j += GHISTOGRAM_THREADS_PER_BLOCK ) + { + atomicAdd( &gpSumBuffer[BIN_SIZE * i + j], localCounters[i][j] ); + } + } } -extern "C" __global__ void ParallelExclusiveScanAllWG( int* gCount, int* gHistogram, volatile int* gPartialSum, volatile bool* gIsReady ) +template +__device__ __forceinline__ void onesweep_reorder( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, u32 numberOfInputs, u32* gpSumBuffer, + volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, u32 iteration ) { - // Fill the LDS with the partial sum of each segment - __shared__ int blockBuffer[SCAN_WG_SIZE]; + __shared__ u32 pSum[BIN_SIZE]; - blockBuffer[threadIdx.x] = gCount[blockIdx.x * blockDim.x + threadIdx.x]; + struct SMem + { + struct Phase1 + { + u16 blockHistogram[BIN_SIZE]; + u16 lpSum[BIN_SIZE * REORDER_NUMBER_OF_WARPS]; + }; + struct Phase2 + { + RADIX_SORT_KEY_TYPE elements[RADIX_SORT_BLOCK_SIZE]; + }; + struct Phase3 + { + RADIX_SORT_VALUE_TYPE elements[RADIX_SORT_BLOCK_SIZE]; + u8 buckets[RADIX_SORT_BLOCK_SIZE]; + }; - __syncthreads(); + union + { + Phase1 phase1; + Phase2 phase2; + Phase3 phase3; + } u; + }; + __shared__ SMem smem; - // Do parallel exclusive scan on the LDS + u32 bitLocation = startBits + N_RADIX * iteration; + u32 blockIndex = blockIdx.x; + u32 numberOfBlocks = div_round_up( numberOfInputs, RADIX_SORT_BLOCK_SIZE ); - int currentSegmentSum = ldsScanExclusive( blockBuffer, SCAN_WG_SIZE ); + clearShared( smem.u.phase1.lpSum, 0 ); __syncthreads(); - // Sync all the Workgroups to calculate the global offset. 
- - __shared__ int currentGlobalOffset; - WorkgroupSync( threadIdx.x, blockIdx.x, currentSegmentSum, ¤tGlobalOffset, gPartialSum, gIsReady ); + RADIX_SORT_KEY_TYPE keys[REORDER_NUMBER_OF_ITEM_PER_THREAD]; + u32 warpOffsets[REORDER_NUMBER_OF_ITEM_PER_THREAD]; - // Write back the result. - - gHistogram[blockIdx.x * blockDim.x + threadIdx.x] = blockBuffer[threadIdx.x] + currentGlobalOffset; -} - -template -__device__ void SortImpl( int* gSrcKey, int* gSrcVal, int* gDstKey, int* gDstVal, int* gHistogram, int numberOfInputs, int gNItemsPerWG, const int START_BIT, const int N_WGS_EXECUTED ) -{ - const int startOffset = blockIdx.x * gNItemsPerWG; - const int nItemInBlock = ( startOffset + gNItemsPerWG > numberOfInputs ) ? numberOfInputs - startOffset : gNItemsPerWG; + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; - struct ElementLocation + for( int i = 0, k = 0; i < REORDER_NUMBER_OF_ITEM_PER_WARP; i += WARP_SIZE, k++ ) { - u32 localSrcIndex : 12; - u32 localOffset : 12; - u32 bucket : 8; - }; + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + warp * REORDER_NUMBER_OF_ITEM_PER_WARP + i + lane; + if( itemIndex < numberOfInputs ) + { + keys[k] = inputKeys[itemIndex]; + } + } + for( int i = 0, k = 0; i < REORDER_NUMBER_OF_ITEM_PER_WARP; i += WARP_SIZE, k++ ) + { + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + warp * REORDER_NUMBER_OF_ITEM_PER_WARP + i + lane; + u32 bucketIndex = extractDigit( getKeyBits( keys[k] ), bitLocation ); - __shared__ u32 globalOffset[BIN_SIZE]; - __shared__ u32 localPrefixSum[BIN_SIZE]; - __shared__ u32 counters[BIN_SIZE]; - __shared__ u32 matchMasks[SORT_NUM_WARPS_PER_BLOCK][BIN_SIZE]; - __shared__ u8 elementBuckets[SORT_SUBBLOCK_SIZE]; - __shared__ ElementLocation elementLocations[SORT_SUBBLOCK_SIZE]; + // check the attendees + u32 broThreads = +#if defined( ITS ) + __ballot_sync( 0xFFFFFFFF, +#else + __ballot( +#endif + itemIndex < numberOfInputs ); - for( int i = threadIdx.x; i < BIN_SIZE; i += SORT_WG_SIZE 
) - { - // Note: The size of gHistogram is always BIN_SIZE * N_WGS_EXECUTED - globalOffset[i] = gHistogram[i * N_WGS_EXECUTED + blockIdx.x]; + for( int j = 0; j < N_RADIX; ++j ) + { + u32 bit = ( bucketIndex >> j ) & 0x1; + u32 difference = ( 0xFFFFFFFF * bit ) ^ +#if defined( ITS ) + __ballot_sync( 0xFFFFFFFF, bit != 0 ); +#else + __ballot( bit != 0 ); +#endif + broThreads &= ~difference; + } + + u32 lowerMask = ( 1u << lane ) - 1; + auto digitCount = smem.u.phase1.lpSum[bucketIndex * REORDER_NUMBER_OF_WARPS + warp]; + warpOffsets[k] = digitCount + __popc( broThreads & lowerMask ); + +#if defined( ITS ) + __syncwarp( 0xFFFFFFFF ); +#else + __syncthreads(); +#endif + u32 leaderIdx = __ffs( broThreads ) - 1; + if( lane == leaderIdx ) + { + smem.u.phase1.lpSum[bucketIndex * REORDER_NUMBER_OF_WARPS + warp] = digitCount + __popc( broThreads ); + } +#if defined( ITS ) + __syncwarp( 0xFFFFFFFF ); +#else + __syncthreads(); +#endif } - for( int w = 0; w < SORT_NUM_WARPS_PER_BLOCK; ++w ) + __syncthreads(); + + for( int bucketIndex = threadIdx.x; bucketIndex < BIN_SIZE; bucketIndex += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) { - for( int i = threadIdx.x; i < BIN_SIZE; i += SORT_WG_SIZE ) + u32 s = 0; + for( int warp = 0; warp < REORDER_NUMBER_OF_WARPS; warp++ ) { - matchMasks[w][i] = 0; + s += smem.u.phase1.lpSum[bucketIndex * REORDER_NUMBER_OF_WARPS + warp]; } + smem.u.phase1.blockHistogram[bucketIndex] = s; } + struct ParitionID + { + u64 value : 32; + u64 block : 30; + u64 flag : 2; + }; + auto asPartition = []( u64 x ) + { + ParitionID pa; + memcpy( &pa, &x, sizeof( ParitionID ) ); + return pa; + }; + auto asU64 = []( ParitionID pa ) + { + u64 x; + memcpy( &x, &pa, sizeof( u64 ) ); + return x; + }; + + if( threadIdx.x == 0 && LOOKBACK_TABLE_SIZE <= blockIndex ) + { + // Wait until blockIndex < tail - MAX_LOOK_BACK + LOOKBACK_TABLE_SIZE + while( ( atomicAdd( tailIterator, 0 ) & TAIL_MASK ) - MAX_LOOK_BACK + LOOKBACK_TABLE_SIZE <= blockIndex ) + ; + } __syncthreads(); - for( 
int j = 0; j < nItemInBlock; j += SORT_SUBBLOCK_SIZE ) + for( int i = threadIdx.x; i < BIN_SIZE; i += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) { - for( int i = threadIdx.x; i < BIN_SIZE; i += SORT_WG_SIZE ) + u32 s = smem.u.phase1.blockHistogram[i]; + int pIndex = BIN_SIZE * ( blockIndex % LOOKBACK_TABLE_SIZE ) + i; + { - counters[i] = 0; - localPrefixSum[i] = 0; + ParitionID pa; + pa.value = s; + pa.block = blockIndex; + pa.flag = 1; + lookBackBuffer[pIndex] = asU64( pa ); } - __syncthreads(); - for( int i = 0; i < SORT_SUBBLOCK_SIZE; i += SORT_WG_SIZE ) + u32 gp = gpSumBuffer[iteration * BIN_SIZE + i]; + + u32 p = 0; + + for( int iBlock = (int)blockIndex - 1; 0 <= iBlock; iBlock-- ) { - const auto itemIndex = blockIdx.x * gNItemsPerWG + j + i + threadIdx.x; - if( itemIndex < numberOfInputs ) + int lookbackIndex = BIN_SIZE * ( iBlock % LOOKBACK_TABLE_SIZE ) + i; + ParitionID pa; + + // when you reach to the maximum, flag must be 2. flagRequire = 0b10 + // Otherwise, flag can be 1 or 2 flagRequire = 0b11 + int flagRequire = MAX_LOOK_BACK == blockIndex - iBlock ? 
2 : 3; + + do { - const auto item = gSrcKey[itemIndex]; - const u32 bucketIndex = getMaskedBits( item, START_BIT ); - atomicInc( &localPrefixSum[bucketIndex], 0xFFFFFFFF ); - elementBuckets[i + threadIdx.x] = static_cast(bucketIndex); + pa = asPartition( lookBackBuffer[lookbackIndex] ); + } while( ( pa.flag & flagRequire ) == 0 || pa.block != iBlock ); + + u32 value = pa.value; + p += value; + if( pa.flag == 2 ) + { + break; } } - __syncthreads(); + ParitionID pa; + pa.value = p + s; + pa.block = blockIndex; + pa.flag = 2; + lookBackBuffer[pIndex] = asU64( pa ); - ldsScanExclusive( localPrefixSum, BIN_SIZE ); + // complete global output location + u32 globalOutput = gp + p; + pSum[i] = globalOutput; + } - __syncthreads(); + __syncthreads(); - for( int i = 0; i < SORT_SUBBLOCK_SIZE; i += SORT_WG_SIZE ) - { - const auto itemIndex = blockIdx.x * gNItemsPerWG + j + i + threadIdx.x; - const u32 bucketIndex = elementBuckets[i + threadIdx.x]; + if( threadIdx.x == 0 ) + { + while( ( atomicAdd( tailIterator, 0 ) & TAIL_MASK ) != ( blockIndex & TAIL_MASK ) ) + ; - const int warp = threadIdx.x / 32; - const int lane = threadIdx.x % 32; + atomicInc( tailIterator, numberOfBlocks - 1 /* after the vary last item, it will be zero */ ); + } - __syncthreads(); + __syncthreads(); - if( itemIndex < numberOfInputs ) - { - atomicOr( &matchMasks[warp][bucketIndex], 1u << lane ); - } + u32 prefix = 0; + for( int i = 0; i < BIN_SIZE; i += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) + { + prefix += scanExclusive( prefix, smem.u.phase1.blockHistogram + i, min( REORDER_NUMBER_OF_THREADS_PER_BLOCK, BIN_SIZE ) ); + } - __syncthreads(); + for( int bucketIndex = threadIdx.x; bucketIndex < BIN_SIZE; bucketIndex += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) + { + u32 s = smem.u.phase1.blockHistogram[bucketIndex]; - bool flushMask = false; + pSum[bucketIndex] -= s; // pre-substruct to avoid pSum[bucketIndex] + i - smem.u.phase1.blockHistogram[bucketIndex] to calculate destinations - if( itemIndex < 
numberOfInputs ) - { - const u32 matchMask = matchMasks[warp][bucketIndex]; - const u32 lowerMask = ( 1u << lane ) - 1; - u32 offset = __popc( matchMask & lowerMask ); + for( int w = 0; w < REORDER_NUMBER_OF_WARPS; w++ ) + { + int index = bucketIndex * REORDER_NUMBER_OF_WARPS + w; + u32 n = smem.u.phase1.lpSum[index]; + smem.u.phase1.lpSum[index] = s; + s += n; + } + } - flushMask = ( offset == 0 ); + __syncthreads(); - for( int w = 0; w < warp; ++w ) - { - offset += __popc( matchMasks[w][bucketIndex] ); - } + for( int k = 0; k < REORDER_NUMBER_OF_ITEM_PER_THREAD; k++ ) + { + u32 bucketIndex = extractDigit( getKeyBits( keys[k] ), bitLocation ); + warpOffsets[k] += smem.u.phase1.lpSum[bucketIndex * REORDER_NUMBER_OF_WARPS + warp]; + } - const u32 localOffset = counters[bucketIndex] + offset; - const u32 to = localOffset + localPrefixSum[bucketIndex]; + __syncthreads(); - ElementLocation el; - el.localSrcIndex = i + threadIdx.x; - el.localOffset = localOffset; - el.bucket = bucketIndex; - elementLocations[to] = el; - } + for( int i = lane, k = 0; i < REORDER_NUMBER_OF_ITEM_PER_WARP; i += WARP_SIZE, k++ ) + { + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + warp * REORDER_NUMBER_OF_ITEM_PER_WARP + i; + u32 bucketIndex = extractDigit( getKeyBits( keys[k] ), bitLocation ); + if( itemIndex < numberOfInputs ) + { + smem.u.phase2.elements[warpOffsets[k]] = keys[k]; + } + } - __syncthreads(); + __syncthreads(); - if( itemIndex < numberOfInputs ) - { - atomicInc( &counters[bucketIndex], 0xFFFFFFFF ); - } + for( int i = threadIdx.x; i < RADIX_SORT_BLOCK_SIZE; i += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) + { + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + i; + if( itemIndex < numberOfInputs ) + { + auto item = smem.u.phase2.elements[i]; + u32 bucketIndex = extractDigit( getKeyBits( item ), bitLocation ); - if( flushMask ) - { - matchMasks[warp][bucketIndex] = 0; - } + // u32 dstIndex = pSum[bucketIndex] + i - smem.u.phase1.blockHistogram[bucketIndex]; + u32 dstIndex 
= pSum[bucketIndex] + i; + outputKeys[dstIndex] = item; } + } + + if constexpr( keyPair ) + { + __syncthreads(); - for( int i = 0; i < SORT_SUBBLOCK_SIZE; i += SORT_WG_SIZE ) + for( int i = lane, k = 0; i < REORDER_NUMBER_OF_ITEM_PER_WARP; i += WARP_SIZE, k++ ) { - const int itemIndex = blockIdx.x * gNItemsPerWG + j + i + threadIdx.x; + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + warp * REORDER_NUMBER_OF_ITEM_PER_WARP + i; + u32 bucketIndex = extractDigit( getKeyBits( keys[k] ), bitLocation ); if( itemIndex < numberOfInputs ) { - const auto el = elementLocations[i + threadIdx.x]; - const auto srcIndex = blockIdx.x * gNItemsPerWG + j + el.localSrcIndex; - const auto bucketIndex = el.bucket; - - const auto dstIndex = globalOffset[bucketIndex] + el.localOffset; - gDstKey[dstIndex] = gSrcKey[srcIndex]; - - if constexpr( KEY_VALUE_PAIR ) - { - gDstVal[dstIndex] = gSrcVal[srcIndex]; - } + smem.u.phase3.elements[warpOffsets[k]] = inputValues[itemIndex]; + smem.u.phase3.buckets[warpOffsets[k]] = bucketIndex; } } __syncthreads(); - for( int i = threadIdx.x; i < BIN_SIZE; i += SORT_WG_SIZE ) + for( int i = threadIdx.x; i < RADIX_SORT_BLOCK_SIZE; i += REORDER_NUMBER_OF_THREADS_PER_BLOCK ) { - globalOffset[i] += counters[i]; - } + u32 itemIndex = blockIndex * RADIX_SORT_BLOCK_SIZE + i; + if( itemIndex < numberOfInputs ) + { + auto item = smem.u.phase3.elements[i]; + u32 bucketIndex = smem.u.phase3.buckets[i]; - __syncthreads(); + // u32 dstIndex = pSum[bucketIndex] + i - smem.u.phase1.blockHistogram[bucketIndex]; + u32 dstIndex = pSum[bucketIndex] + i; + outputValues[dstIndex] = item; + } + } } } - -extern "C" __global__ void SortKernel( int* gSrcKey, int* gDstKey, int* gHistogram, int gN, int gNItemsPerWG, const int START_BIT, const int N_WGS_EXECUTED ) +extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) onesweep_reorderKey64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, u32 numberOfInputs, u32* gpSumBuffer, 
volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, + u32 iteration ) { - SortImpl( gSrcKey, nullptr, gDstKey, nullptr, gHistogram, gN, gNItemsPerWG, START_BIT, N_WGS_EXECUTED ); + onesweep_reorder( inputKeys, outputKeys, nullptr, nullptr, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); } - -extern "C" __global__ void SortKVKernel( int* gSrcKey, int* gSrcVal, int* gDstKey, int* gDstVal, int* gHistogram, int gN, int gNItemsPerWG, const int START_BIT, const int N_WGS_EXECUTED ) +extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) onesweep_reorderKeyPair64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, + u32 numberOfInputs, + u32* gpSumBuffer, + volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, u32 iteration ) { - SortImpl( gSrcKey, gSrcVal, gDstKey, gDstVal, gHistogram, gN, gNItemsPerWG, START_BIT, N_WGS_EXECUTED ); -} + onesweep_reorder( inputKeys, outputKeys, inputValues, outputValues, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); +} \ No newline at end of file diff --git a/Test/RadixSort/main.cpp b/Test/RadixSort/main.cpp index 645480fd..090da603 100644 --- a/Test/RadixSort/main.cpp +++ b/Test/RadixSort/main.cpp @@ -68,6 +68,7 @@ class SortTest OrochiUtils::malloc( dstGpu.key, testSize ); std::vector srcKey( testSize ); + for( int i = 0; i < testSize; i++ ) { srcKey[i] = getRandom( 0u, (u32)( ( 1ull << (u64)testBits ) - 1 ) ); @@ -85,7 +86,6 @@ class SortTest } } - Stopwatch sw; for( int i = 0; i < nRuns; i++ ) { OrochiUtils::copyHtoD( srcGpu.key, srcKey.data(), testSize ); @@ -97,7 +97,8 @@ class SortTest OrochiUtils::waitForCompletion(); } - sw.start(); + OroStopwatch oroStream( nullptr ); + oroStream.start(); if constexpr( KEY_VALUE_PAIR ) { @@ -108,9 +109,10 @@ class SortTest m_sort.sort( srcGpu.key, dstGpu.key, testSize, 0, testBits ); } + 
oroStream.stop(); + OrochiUtils::waitForCompletion(); - sw.stop(); - float ms = sw.getMs(); + float ms = oroStream.getMs(); float gKeys_s = static_cast( testSize ) / 1000.f / 1000.f / ms; printf( "%5.2fms (%3.2fGKeys/s) sorting %3.1fMkeys [%s]\n", ms, gKeys_s, testSize / 1000.f / 1000.f, KEY_VALUE_PAIR ? "keyValue" : "key" ); } @@ -290,6 +292,7 @@ enum TestType TEST_SIMPLE, TEST_PERF, TEST_BITS, + TEST_CAPTURE, TEST_MISC, }; @@ -370,7 +373,11 @@ int main( int argc, char** argv ) sort.test( testSize, 32, nRuns ); } break; - + case TEST_CAPTURE: + { + sort.test( 1u << 27 /*2^29*/, 32, 9999999 ); + } + break; case TEST_MISC: { static constexpr auto file = "input.txt"; From c61000573356f3bedc206c54ca45118c86f03a3f Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 09:22:27 +0900 Subject: [PATCH 02/11] Removing shadwing vars --- ParallelPrimitives/RadixSort.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 62c74159..47fc4790 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -148,9 +148,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string } else { - const auto includeArg{ "-I" + currentIncludeDir }; - std::vector opts; - opts.push_back( includeArg.c_str() ); oroFunctions[record.kernelType] = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts ); } From 8aa523c267377827e2670fb9859c594ecc1f246e Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 09:57:05 +0900 Subject: [PATCH 03/11] Updating compilation options --- ParallelPrimitives/RadixSort.cpp | 61 +++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 47fc4790..26d9289c 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -5,40 +5,66 @@ 
#include #include -#if defined( ORO_PP_LOAD_FROM_STRING ) - +// if ORO_PP_LOAD_FROM_STRING && ORO_PRECOMPILED -> we load the precompiled/baked kernels. +// if ORO_PP_LOAD_FROM_STRING && NOT ORO_PRECOMPILED -> we load the baked source code kernels (from Kernels.h / KernelArgs.h) +#if !defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) // Note: the include order must be in this particular form. // clang-format off #include #include // clang-format on +#else +// if Kernels.h / KernelArgs.h are not included, declare nullptr strings +static const char* hip_RadixSortKernels = nullptr; +namespace hip +{ +static const char** RadixSortKernelsArgs = nullptr; +static const char** RadixSortKernelsIncludes = nullptr; +} // namespace hip #endif #if defined( __GNUC__ ) #include #endif +#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) +#include // generate this header with 'convert_binary_to_array.py' +#else +const unsigned char oro_compiled_kernels_h[] = ""; +const size_t oro_compiled_kernels_h_size = 0; +#endif + constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; } constexpr uint64_t next_multiple64( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64( val, divisor ) * divisor; } namespace { + +// if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode. 
+#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING ) + +// this flag means that we bake the precompiled kernels +constexpr auto usePrecompiledAndBakedKernel = true; + +constexpr auto useBitCode = false; +constexpr auto useBakeKernel = false; + +#else + +constexpr auto usePrecompiledAndBakedKernel = false; + #if defined( ORO_PRECOMPILED ) -constexpr auto useBitCode = true; +constexpr auto useBitCode = true; // this flag means we use the bitcode file #else constexpr auto useBitCode = false; #endif #if defined( ORO_PP_LOAD_FROM_STRING ) -constexpr auto useBakeKernel = true; +constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embedded in the binary ( as a string ) #else constexpr auto useBakeKernel = false; -static const char* hip_RadixSortKernels = nullptr; -namespace hip -{ -static const char** RadixSortKernelsArgs = nullptr; -static const char** RadixSortKernelsIncludes = nullptr; -} // namespace hip +#endif + #endif static_assert( !( useBitCode && useBakeKernel ), "useBitCode and useBakeKernel cannot coexist" ); @@ -138,11 +164,15 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string for( const auto& record : records ) { -#if defined( ORO_PP_LOAD_FROM_STRING ) - oroFunctions[record.kernelType] = oroutils.getFunctionFromString( device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes ); -#else - - if constexpr( useBitCode ) + if constexpr( usePrecompiledAndBakedKernel ) + { + oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() ); + } + else if constexpr( useBakeKernel ) + { + oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, 
hip::RadixSortKernelsIncludes ); + } + else if constexpr( useBitCode ) { oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() ); } @@ -151,7 +181,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string oroFunctions[record.kernelType] = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts ); } -#endif if( m_flags == Flag::LOG ) { printKernelInfo( record.kernelName, oroFunctions[record.kernelType] ); From 745817416b2fba955f7d22271e0fd0e038630363 Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 10:23:26 +0900 Subject: [PATCH 04/11] Support for bitcode and other compilation options --- ParallelPrimitives/RadixSort.cpp | 23 +++++++++-------------- ParallelPrimitives/RadixSort.h | 3 +++ ParallelPrimitives/RadixSortKernels.h | 12 ++++++------ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 26d9289c..9ad4501e 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -157,9 +157,11 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string Kernel kernelType; }; - const std::vector records{ - { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV }, - }; + const std::vector records{ { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, + { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV }, + { "GHistogram", Kernel::SORT_GHISTOGRAM }, + { "OnesweepReorderKey64", Kernel::SORT_ONESWEEP_REORDER_KEY_64 }, + { "OnesweepReorderKeyPair64", Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } }; for( const auto& record : records ) @@ -187,13 +189,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string } } - // TODO: bit code support? 
-#define LOAD_FUNC( var, kernel ) var = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), kernel, &opts ); - LOAD_FUNC( m_gHistogram, "gHistogram" ); - LOAD_FUNC( m_onesweep_reorderKey64, "onesweep_reorderKey64" ); - LOAD_FUNC( m_onesweep_reorderKeyPair64, "onesweep_reorderKeyPair64" ); -#undef LOAD_FUNC - } void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept @@ -245,11 +240,11 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n // counter for gHistogram. { int maxBlocksPerMP = 0; - oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, m_gHistogram, GHISTOGRAM_THREADS_PER_BLOCK, 0 ); + oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 ); const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048; const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) }; - OrochiUtils::launch1D( m_gHistogram, nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream ); + OrochiUtils::launch1D( oroFunctions[Kernel::SORT_GHISTOGRAM], nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream ); } auto s = src; @@ -264,12 +259,12 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n if( keyPair ) { const void* args[] = { &s.key, &d.key, &s.value, &d.value, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i }; - OrochiUtils::launch1D( m_onesweep_reorderKeyPair64, numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); + OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64], numberOfBlocks * 
REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); } else { const void* args[] = { &s.key, &d.key, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i }; - OrochiUtils::launch1D( m_onesweep_reorderKey64, numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); + OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream ); } std::swap( s, d ); } diff --git a/ParallelPrimitives/RadixSort.h b/ParallelPrimitives/RadixSort.h index a20c9f0c..dd62f107 100644 --- a/ParallelPrimitives/RadixSort.h +++ b/ParallelPrimitives/RadixSort.h @@ -65,6 +65,9 @@ class RadixSort final { SORT_SINGLE_PASS, SORT_SINGLE_PASS_KV, + SORT_GHISTOGRAM, + SORT_ONESWEEP_REORDER_KEY_64, + SORT_ONESWEEP_REORDER_KEY_PAIR_64 }; std::unordered_map oroFunctions; diff --git a/ParallelPrimitives/RadixSortKernels.h b/ParallelPrimitives/RadixSortKernels.h index 75e94517..57311b3c 100644 --- a/ParallelPrimitives/RadixSortKernels.h +++ b/ParallelPrimitives/RadixSortKernels.h @@ -322,7 +322,7 @@ __device__ inline T scanExclusive( T prefix, T* sMemIO, int nElement ) return sum; } -extern "C" __global__ void gHistogram( RADIX_SORT_KEY_TYPE* inputs, u32 numberOfInputs, u32* gpSumBuffer, u32 startBits, u32* counter ) +extern "C" __global__ void GHistogram( RADIX_SORT_KEY_TYPE* inputs, u32 numberOfInputs, u32* gpSumBuffer, u32 startBits, u32* counter ) { __shared__ u32 localCounters[sizeof( RADIX_SORT_KEY_TYPE )][BIN_SIZE]; @@ -381,7 +381,7 @@ extern "C" __global__ void gHistogram( RADIX_SORT_KEY_TYPE* inputs, u32 numberOf } template -__device__ __forceinline__ void onesweep_reorder( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, 
u32 numberOfInputs, u32* gpSumBuffer, +__device__ __forceinline__ void OnesweepReorder( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, u32 numberOfInputs, u32* gpSumBuffer, volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, u32 iteration ) { __shared__ u32 pSum[BIN_SIZE]; @@ -671,15 +671,15 @@ __device__ __forceinline__ void onesweep_reorder( RADIX_SORT_KEY_TYPE* inputKeys } } } -extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) onesweep_reorderKey64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, u32 numberOfInputs, u32* gpSumBuffer, volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, +extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) OnesweepReorderKey64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, u32 numberOfInputs, u32* gpSumBuffer, volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, u32 iteration ) { - onesweep_reorder( inputKeys, outputKeys, nullptr, nullptr, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); + OnesweepReorder( inputKeys, outputKeys, nullptr, nullptr, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); } -extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) onesweep_reorderKeyPair64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, +extern "C" __global__ void __launch_bounds__( REORDER_NUMBER_OF_THREADS_PER_BLOCK ) OnesweepReorderKeyPair64( RADIX_SORT_KEY_TYPE* inputKeys, RADIX_SORT_KEY_TYPE* outputKeys, RADIX_SORT_VALUE_TYPE* inputValues, RADIX_SORT_VALUE_TYPE* outputValues, u32 numberOfInputs, u32* gpSumBuffer, volatile u64* lookBackBuffer, u32* tailIterator, u32 startBits, u32 iteration ) { - onesweep_reorder( inputKeys, 
outputKeys, inputValues, outputValues, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); + OnesweepReorder( inputKeys, outputKeys, inputValues, outputValues, numberOfInputs, gpSumBuffer, lookBackBuffer, tailIterator, startBits, iteration ); } \ No newline at end of file From 06b465d12c46786d9763491d13af2988fa758b18 Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 11:13:23 +0900 Subject: [PATCH 05/11] Oroutils warning resolved --- Test/RadixSort/main.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Test/RadixSort/main.cpp b/Test/RadixSort/main.cpp index 090da603..815da77a 100644 --- a/Test/RadixSort/main.cpp +++ b/Test/RadixSort/main.cpp @@ -389,6 +389,8 @@ int main( int argc, char** argv ) break; }; + oroutils.unloadKernelCache(); + printf( ">> done\n" ); return 0; } From 9bfb5d63851c1187cee5889fb2bf2c129ed180ab Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Wed, 19 Feb 2025 14:03:42 +0900 Subject: [PATCH 06/11] Occupancy function fix --- ParallelPrimitives/RadixSort.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 9ad4501e..1c96cf5b 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -240,7 +240,7 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n // counter for gHistogram. { int maxBlocksPerMP = 0; - oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 ); + oroError e = oroModuleOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 ); const int nBlocks = e == oroSuccess ? 
maxBlocksPerMP * m_props.multiProcessorCount : 2048; const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) }; From 36f032845b1c80c17c6b1523bd651ff0eca181f2 Mon Sep 17 00:00:00 2001 From: Daniel Meister Date: Mon, 24 Mar 2025 13:50:12 +0900 Subject: [PATCH 07/11] License --- ParallelPrimitives/RadixSort.cpp | 22 ++++++++++++++++++++++ ParallelPrimitives/RadixSort.h | 22 ++++++++++++++++++++++ ParallelPrimitives/RadixSortConfigs.h | 22 ++++++++++++++++++++++ ParallelPrimitives/RadixSortKernels.h | 22 ++++++++++++++++++++++ 4 files changed, 88 insertions(+) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 1c96cf5b..73b9268a 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -1,3 +1,25 @@ +// +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + #include #include #include diff --git a/ParallelPrimitives/RadixSort.h b/ParallelPrimitives/RadixSort.h index dd62f107..3b044d3c 100644 --- a/ParallelPrimitives/RadixSort.h +++ b/ParallelPrimitives/RadixSort.h @@ -1,3 +1,25 @@ +// +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + #pragma once #include diff --git a/ParallelPrimitives/RadixSortConfigs.h b/ParallelPrimitives/RadixSortConfigs.h index 40cd1120..d06cde55 100644 --- a/ParallelPrimitives/RadixSortConfigs.h +++ b/ParallelPrimitives/RadixSortConfigs.h @@ -1,3 +1,25 @@ +// +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. 
All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + #pragma once namespace Oro diff --git a/ParallelPrimitives/RadixSortKernels.h b/ParallelPrimitives/RadixSortKernels.h index 57311b3c..3fe37293 100644 --- a/ParallelPrimitives/RadixSortKernels.h +++ b/ParallelPrimitives/RadixSortKernels.h @@ -1,3 +1,25 @@ +// +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+// + #include #define LDS_BARRIER __syncthreads() From 571b423666f396bf9d5f118ccaf63fa29b908983 Mon Sep 17 00:00:00 2001 From: Chih-Chen Kao Date: Mon, 24 Mar 2025 14:27:55 +0100 Subject: [PATCH 08/11] Use constexpr to calculate "key_type_size" Signed-off-by: Chih-Chen Kao --- ParallelPrimitives/RadixSort.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 73b9268a..7b64e12d 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -217,15 +217,18 @@ void RadixSort::configure( const std::string& kernelPath, const std::string& inc { compileKernels( kernelPath, includeDir ); - u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * sizeof( u32 /* key type */ ); - m_gpSumBuffer.resizeAsync( gpSumBuffer, false /*copy*/, stream ); + constexpr bool enable_copying = false; + constexpr auto key_type_size = sizeof(std::remove_pointer_t); + + constexpr u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * key_type_size; + m_gpSumBuffer.resizeAsync( gpSumBuffer, enable_copying /*copy*/, stream ); u64 lookBackBuffer = sizeof( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE ); - m_lookbackBuffer.resizeAsync( lookBackBuffer, false /*copy*/, stream ); + m_lookbackBuffer.resizeAsync( lookBackBuffer, enable_copying /*copy*/, stream ); - m_tailIterator.resizeAsync( 1, false /*copy*/, stream ); + m_tailIterator.resizeAsync( 1, enable_copying /*copy*/, stream ); m_tailIterator.resetAsync( stream ); - m_gpSumCounter.resizeAsync( 1, false /*copy*/, stream ); + m_gpSumCounter.resizeAsync( 1, enable_copying /*copy*/, stream ); } void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; } From 34aa200aa70d183c085bbf88db1d9c0057cf228f Mon Sep 17 00:00:00 2001 From: Chih-Chen Kao Date: Mon, 24 Mar 2025 14:29:16 +0100 Subject: [PATCH 09/11] introduce a constexpr variable "bit_per_iteration" since it's a 8-bit radix sort Signed-off-by: Chih-Chen Kao --- 
ParallelPrimitives/RadixSort.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 7b64e12d..bfa928b8 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -255,7 +255,9 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n return; } - int nIteration = div_round_up64( endBit - startBit, 8 ); + constexpr uint64_t bit_per_iteration = 8ULL; + + int nIteration = div_round_up64( endBit - startBit, bit_per_iteration); uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE ); m_lookbackBuffer.resetAsync( stream ); @@ -274,7 +276,7 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n auto s = src; auto d = dst; - for( int i = 0; i < nIteration; i++ ) + for( int i = 0; i < nIteration; ++i ) { if( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 ) { From a052b62fa436dc1bed268f706cfe4a358c3951b7 Mon Sep 17 00:00:00 2001 From: Chih-Chen Kao Date: Mon, 24 Mar 2025 14:35:26 +0100 Subject: [PATCH 10/11] Use m_oroutils.copyDtoDAsync Signed-off-by: Chih-Chen Kao --- ParallelPrimitives/RadixSort.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index bfa928b8..914cb438 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -298,11 +298,11 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n if( s.key == src.key ) { - oroMemcpyDtoDAsync( (oroDeviceptr)dst.key, (oroDeviceptr)src.key, sizeof( uint32_t ) * n, stream ); + m_oroutils.copyDtoDAsync(dst.key, src.key, n, stream); if( keyPair ) { - oroMemcpyDtoDAsync( (oroDeviceptr)dst.value, (oroDeviceptr)src.value, sizeof( uint32_t ) * n, stream ); + m_oroutils.copyDtoDAsync(dst.value, src.value, n, stream); } } } From e8e2d14bb2cdb06a4b01ba735178b4d241fc6f0c Mon Sep 17 00:00:00 2001 From: 
Chih-Chen Kao Date: Mon, 24 Mar 2025 14:36:06 +0100 Subject: [PATCH 11/11] fix format Signed-off-by: Chih-Chen Kao --- ParallelPrimitives/RadixSort.cpp | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/ParallelPrimitives/RadixSort.cpp b/ParallelPrimitives/RadixSort.cpp index 914cb438..dcef2f4b 100644 --- a/ParallelPrimitives/RadixSort.cpp +++ b/ParallelPrimitives/RadixSort.cpp @@ -76,7 +76,7 @@ constexpr auto useBakeKernel = false; constexpr auto usePrecompiledAndBakedKernel = false; #if defined( ORO_PRECOMPILED ) -constexpr auto useBitCode = true; // this flag means we use the bitcode file +constexpr auto useBitCode = true; // this flag means we use the bitcode file #else constexpr auto useBitCode = false; #endif @@ -185,12 +185,11 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string { "OnesweepReorderKey64", Kernel::SORT_ONESWEEP_REORDER_KEY_64 }, { "OnesweepReorderKeyPair64", Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } }; - for( const auto& record : records ) { - if constexpr( usePrecompiledAndBakedKernel ) - { - oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() ); + if constexpr( usePrecompiledAndBakedKernel ) + { + oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() ); } else if constexpr( useBakeKernel ) { @@ -210,7 +209,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string printKernelInfo( record.kernelName, oroFunctions[record.kernelType] ); } } - } void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept @@ -218,7 +216,7 @@ void RadixSort::configure( const std::string& kernelPath, const std::string& inc compileKernels( kernelPath, includeDir ); constexpr bool 
enable_copying = false; - constexpr auto key_type_size = sizeof(std::remove_pointer_t); + constexpr auto key_type_size = sizeof( std::remove_pointer_t ); constexpr u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * key_type_size; m_gpSumBuffer.resizeAsync( gpSumBuffer, enable_copying /*copy*/, stream ); @@ -257,14 +255,14 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n constexpr uint64_t bit_per_iteration = 8ULL; - int nIteration = div_round_up64( endBit - startBit, bit_per_iteration); + int nIteration = div_round_up64( endBit - startBit, bit_per_iteration ); uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE ); m_lookbackBuffer.resetAsync( stream ); m_gpSumCounter.resetAsync( stream ); m_gpSumBuffer.resetAsync( stream ); - // counter for gHistogram. + // counter for gHistogram. { int maxBlocksPerMP = 0; oroError e = oroModuleOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 ); @@ -298,17 +296,14 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n if( s.key == src.key ) { - m_oroutils.copyDtoDAsync(dst.key, src.key, n, stream); + m_oroutils.copyDtoDAsync( dst.key, src.key, n, stream ); if( keyPair ) { - m_oroutils.copyDtoDAsync(dst.value, src.value, n, stream); + m_oroutils.copyDtoDAsync( dst.value, src.value, n, stream ); } } } -void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept -{ - sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); -} +void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept { sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); } }; // namespace Oro