Changes from all commits
18 commits
d1744fa
Implement batched random integer generation for shuffle
FranciscoThiesen Dec 2, 2025
3df8856
Address review feedback: extract URNG range check to variable template
FranciscoThiesen Dec 3, 2025
8aa184a
Apply clang-format
FranciscoThiesen Dec 3, 2025
3a8bff6
Merge branch 'main' into FranciscoGeimanThiesen/faster_shuffle_batche…
StephanTLavavej Feb 21, 2026
b838edc
Naming: `_B1` => `_Bx1`, `_B2` => `_Bx2`
StephanTLavavej Feb 21, 2026
3091d9d
Scope `_Diff` more tightly.
StephanTLavavej Feb 21, 2026
9c63953
Add `const` to `_UFirst`.
StephanTLavavej Feb 21, 2026
bbad415
Drop unnecessary `static_cast<uint64_t>` for `_Ref()`, add internal s…
StephanTLavavej Feb 21, 2026
592335d
Add `const` to `_Bound`, `_Bound1`, `_Bound2` params.
StephanTLavavej Feb 21, 2026
d9c6d6b
`_Batch_2()` should take a reference-to-array.
StephanTLavavej Feb 21, 2026
0b9e0ce
Extract `static_cast<uint64_t>(_Bound)` as `_Bx`.
StephanTLavavej Feb 21, 2026
6706d3b
Drop unused bounds, some with mathematically incorrect comments.
StephanTLavavej Feb 21, 2026
1655958
Comment nitpicks: RNGs => URNGs, RNG => URNG
StephanTLavavej Feb 21, 2026
c5019c5
Extract `original`, add `const`.
StephanTLavavej Feb 21, 2026
318f186
Avoid unnecessary `sorted_v` copy.
StephanTLavavej Feb 21, 2026
237cf7b
Repeat comments for clarity.
StephanTLavavej Feb 21, 2026
dfbba0e
Condense code to make the pattern clearer on a single screen.
StephanTLavavej Feb 21, 2026
82d9e67
Extract `shuffle_is_a_permutation()`, test a large odd size (1729).
StephanTLavavej Feb 21, 2026
262 changes: 254 additions & 8 deletions stl/inc/algorithm
@@ -6178,6 +6178,87 @@ namespace ranges {
#endif // _HAS_CXX20
#endif // _HAS_CXX17

// Batched random integer generation for shuffle optimization.
// From Nevin Brackett-Rozinsky and Daniel Lemire, "Batched Ranged Random Integer Generation",
// Software: Practice and Experience 55(1), 2025.
//
// This algorithm extracts multiple bounded random integers from a single 64-bit random word,
// using only multiplication (no division) in the common case.
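//
// As a rough sketch of the core identity (illustrative only; it assumes a compiler with
// unsigned __int128, while the code below uses _Base128::_UMul128):
//
// uint64_t bounded(uint64_t random_word, uint64_t bound, uint64_t& leftover) { // hypothetical helper
// const unsigned __int128 product = static_cast<unsigned __int128>(random_word) * bound;
// leftover = static_cast<uint64_t>(product); // low 64 bits: reusable entropy
// return static_cast<uint64_t>(product >> 64); // high 64 bits: a value in [0, bound)
// }
//
// Feeding the leftover into the next multiplication is what lets one 64-bit word serve several draws.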

// Check if a URNG has full 64-bit range [0, 2^64 - 1].
// Batched random generation is only beneficial for such URNGs.
template <class _Urng>
constexpr bool _Urng_has_full_64bit_range =
is_same_v<_Invoke_result_t<_Urng&>, uint64_t> && (_Urng::min) () == 0 && (_Urng::max) () == _Max_limit<uint64_t>();
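// For example, mt19937_64 satisfies this on this implementation (it returns uint64_t over the full
// [0, 2^64 - 1] range), while 32-bit URNGs such as mt19937 take the non-batched _Rng_from_urng_v2 path.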

template <class _Diff, class _Urng>
struct _Batched_rng_from_urng {
_STL_INTERNAL_STATIC_ASSERT(_Urng_has_full_64bit_range<_Urng>);

// Bound threshold for batching, derived from the paper to minimize expected cost per random value.
// A batch of k values requires the product of the k consecutive bounds to fit in 2^64.
static constexpr uint64_t _Bound_for_batch_2 = 4294967296; // 2^32; with both bounds <= 2^32,
// the product of a batch of 2 stays below 2^64

_Urng& _Ref;

explicit _Batched_rng_from_urng(_Urng& _Func) noexcept : _Ref(_Func) {}

// Generate a single bounded random value in [0, _Bound) using Lemire's method.
_NODISCARD _Diff _Single_bounded(const _Diff _Bound) {
const uint64_t _Bx = static_cast<uint64_t>(_Bound);
_Unsigned128 _Product{_Base128::_UMul128(_Ref(), _Bx, _Product._Word[1])};
auto _Leftover = _Product._Word[0];

if (_Leftover < _Bx) {
const uint64_t _Threshold = (0 - _Bx) % _Bx;
while (_Leftover < _Threshold) {
_Product = _Unsigned128{_Base128::_UMul128(_Ref(), _Bx, _Product._Word[1])};
_Leftover = _Product._Word[0];
}
}

return static_cast<_Diff>(_Product._Word[1]);
}
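
// Worked example (illustrative): for _Bound == 6, _Threshold == (2^64 - 6) % 6 == 4, so only the
// 4 random words (out of 2^64) whose leftover falls below 4 are rejected; the remaining 2^64 - 4
// words map exactly (2^64 - 4) / 6 times onto each of the six outcomes, giving an unbiased result.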

// Generate two bounded random values from a single 64-bit random word.
// The bounds are _Bound1 == i + 1 and _Bound2 == i + 2 for Fisher-Yates positions i and i + 1.
void _Batch_2(_Diff (&_Results)[2], const _Diff _Bound1, const _Diff _Bound2) {
const uint64_t _Bx1 = static_cast<uint64_t>(_Bound1);
const uint64_t _Bx2 = static_cast<uint64_t>(_Bound2);
const uint64_t _Product_bound = _Bx1 * _Bx2; // cannot overflow: the caller ensures _Bx1 < _Bx2 <= 2^32

uint64_t _Random_word = _Ref();

_Unsigned128 _Prod1{_Base128::_UMul128(_Random_word, _Bx1, _Prod1._Word[1])};
_Results[0] = static_cast<_Diff>(_Prod1._Word[1]);
uint64_t _Leftover1 = _Prod1._Word[0];

_Unsigned128 _Prod2{_Base128::_UMul128(_Leftover1, _Bx2, _Prod2._Word[1])};
_Results[1] = static_cast<_Diff>(_Prod2._Word[1]);
uint64_t _Leftover = _Prod2._Word[0];

// Rejection sampling: the cheap _Leftover < _Product_bound test skips the expensive modulo in the common case.
if (_Leftover < _Product_bound) {
const uint64_t _Threshold = (0 - _Product_bound) % _Product_bound;
while (_Leftover < _Threshold) {
_Random_word = _Ref();

_Prod1 = _Unsigned128{_Base128::_UMul128(_Random_word, _Bx1, _Prod1._Word[1])};
_Results[0] = static_cast<_Diff>(_Prod1._Word[1]);
_Leftover1 = _Prod1._Word[0];

_Prod2 = _Unsigned128{_Base128::_UMul128(_Leftover1, _Bx2, _Prod2._Word[1])};
_Results[1] = static_cast<_Diff>(_Prod2._Word[1]);
_Leftover = _Prod2._Word[0];
}
}
}

_Batched_rng_from_urng(const _Batched_rng_from_urng&) = delete;
_Batched_rng_from_urng& operator=(const _Batched_rng_from_urng&) = delete;
};
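
// Usage sketch (illustrative only; the real drivers are the batched shuffles below): with bounds
// _N + 1 and _N + 2 under the batch-2 threshold, one URNG call yields both Fisher-Yates offsets.
//
// mt19937_64 _Gen; // assuming a full-range 64-bit URNG
// _Batched_rng_from_urng<ptrdiff_t, mt19937_64> _Rng(_Gen);
// ptrdiff_t _Offsets[2];
// _Rng._Batch_2(_Offsets, _N + 1, _N + 2); // _Offsets[0] < _N + 1, _Offsets[1] < _N + 2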

template <class _Diff, class _Urng>
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
public:
@@ -6521,11 +6602,91 @@ void _Random_shuffle1(_RanIt _First, _RanIt _Last, _RngFn& _RngFunc) {
}
}

// Batched shuffle implementation for 64-bit URNGs with full range.
// Uses batched random generation to reduce URNG calls.
template <class _RanIt, class _Urng>
void _Random_shuffle_batched(_RanIt _First, _RanIt _Last, _Urng& _Func) {
// shuffle [_First, _Last) using batched random generation
_STD _Adl_verify_range(_First, _Last);
const auto _UFirst = _STD _Get_unwrapped(_First);
const auto _ULast = _STD _Get_unwrapped(_Last);
if (_UFirst == _ULast) {
return;
}

using _Diff = _Iter_diff_t<_RanIt>;
_Batched_rng_from_urng<_Diff, _Urng> _BatchedRng(_Func);

auto _UTarget = _UFirst;
_Diff _Target_index = 1;

// Process pairs using batched generation when beneficial.
// A batch of 2 is beneficial when the larger bound is at most 2^32, so the product of both bounds fits in 64 bits.
while (_UTarget != _ULast) {
++_UTarget;
if (_UTarget == _ULast) {
break;
}

const _Diff _Bound1 = _Target_index + 1; // bound for current position
const _Diff _Bound2 = _Target_index + 2; // bound for next position

// Check if we can batch: both bounds and their product must fit safely.
// Use batch of 2 when the larger bound is <= 2^32 (product fits in 64 bits).
if (static_cast<uint64_t>(_Bound2) <= _Batched_rng_from_urng<_Diff, _Urng>::_Bound_for_batch_2) {
auto _UTarget_next = _UTarget;
++_UTarget_next;

if (_UTarget_next != _ULast) {
// Generate two random indices in one batch.
_Diff _Offsets[2];
_BatchedRng._Batch_2(_Offsets, _Bound1, _Bound2);

_STL_ASSERT(0 <= _Offsets[0] && _Offsets[0] <= _Target_index, "random value out of range");
_STL_ASSERT(0 <= _Offsets[1] && _Offsets[1] <= _Target_index + 1, "random value out of range");

// Perform first swap.
if (_Offsets[0] != _Target_index) {
swap(*_UTarget, *(_UFirst + _Offsets[0])); // intentional ADL
}

// Advance to next position and perform second swap.
++_UTarget;
++_Target_index;

if (_Offsets[1] != _Target_index) {
swap(*_UTarget, *(_UFirst + _Offsets[1])); // intentional ADL
}

++_Target_index; // the increment at the top of the loop advances _UTarget
continue;
}
}

// Fall back to single generation for this position.
const _Diff _Off = _BatchedRng._Single_bounded(_Bound1);
_STL_ASSERT(0 <= _Off && _Off <= _Target_index, "random value out of range");
if (_Off != _Target_index) {
swap(*_UTarget, *(_UFirst + _Off)); // intentional ADL
}

++_Target_index; // the increment at the top of the loop advances _UTarget
}
}
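
// Trace (illustrative) for 5 elements: position 0 never needs a swap; _Target_index visits 1 and 2
// as one batch, then 3 and 4 as a second batch. For 4 elements, the lookahead at position 3 reaches
// _ULast, so that final position falls back to a single _Single_bounded() draw.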

_EXPORT_STD template <class _RanIt, class _Urng>
void shuffle(_RanIt _First, _RanIt _Last, _Urng&& _Func) { // shuffle [_First, _Last) using URNG _Func
using _Urng0 = remove_reference_t<_Urng>;

// Use batched shuffle when the URNG produces full 64-bit range values.
if constexpr (_Urng_has_full_64bit_range<_Urng0>) {
_STD _Random_shuffle_batched(_First, _Last, _Func);
} else {
_Rng_from_urng_v2<_Iter_diff_t<_RanIt>, _Urng0> _RngFunc(_Func);
_STD _Random_shuffle1(_First, _Last, _RngFunc);
}
}
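
// Note (illustrative): the batched path consumes roughly one 64-bit draw per two positions, so for
// a fixed seed, shuffle with mt19937_64 produces a different (still uniform) permutation than the
// previous implementation did.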

#if _HAS_CXX20
@@ -6537,20 +6698,37 @@ namespace ranges {
_STATIC_CALL_OPERATOR _It operator()(_It _First, _Se _Last, _Urng&& _Func) _CONST_CALL_OPERATOR {
_STD _Adl_verify_range(_First, _Last);

using _Urng0 = remove_reference_t<_Urng>;

// Use batched shuffle when the URNG produces full 64-bit range values.
if constexpr (_Urng_has_full_64bit_range<_Urng0>) {
auto _UResult = _Shuffle_unchecked_batched(
_RANGES _Unwrap_iter<_Se>(_STD move(_First)), _RANGES _Unwrap_sent<_It>(_STD move(_Last)), _Func);
_STD _Seek_wrapped(_First, _STD move(_UResult));
} else {
using _Diff = iter_difference_t<_It>;
_Rng_from_urng_v2<_Diff, _Urng0> _RngFunc(_Func);
auto _UResult = _Shuffle_unchecked(_RANGES _Unwrap_iter<_Se>(_STD move(_First)),
_RANGES _Unwrap_sent<_It>(_STD move(_Last)), _RngFunc);
_STD _Seek_wrapped(_First, _STD move(_UResult));
}
return _First;
}

template <random_access_range _Rng, class _Urng>
requires permutable<iterator_t<_Rng>> && uniform_random_bit_generator<remove_reference_t<_Urng>>
_STATIC_CALL_OPERATOR borrowed_iterator_t<_Rng> operator()(_Rng&& _Range, _Urng&& _Func) _CONST_CALL_OPERATOR {
using _Urng0 = remove_reference_t<_Urng>;

// Use batched shuffle when the URNG produces full 64-bit range values.
if constexpr (_Urng_has_full_64bit_range<_Urng0>) {
return _RANGES _Rewrap_iterator(
_Range, _Shuffle_unchecked_batched(_Ubegin(_Range), _Uend(_Range), _Func));
} else {
using _Diff = range_difference_t<_Rng>;
_Rng_from_urng_v2<_Diff, _Urng0> _RngFunc(_Func);
return _RANGES _Rewrap_iterator(_Range, _Shuffle_unchecked(_Ubegin(_Range), _Uend(_Range), _RngFunc));
}
}

private:
@@ -6578,6 +6756,74 @@ namespace ranges {
}
return _Target;
}

// Batched shuffle implementation for ranges.
template <class _It, class _Se, class _Urng>
_NODISCARD static _It _Shuffle_unchecked_batched(_It _First, const _Se _Last, _Urng& _Func) {
// shuffle [_First, _Last) using batched random generation
_STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
_STL_INTERNAL_STATIC_ASSERT(permutable<_It>);

if (_First == _Last) {
return _First;
}

using _Diff = iter_difference_t<_It>;
_Batched_rng_from_urng<_Diff, _Urng> _BatchedRng(_Func);

auto _Target = _First;
_Diff _Target_index = 1;

// Process pairs using batched generation when beneficial.
while (_Target != _Last) {
++_Target;
if (_Target == _Last) {
break;
}

const _Diff _Bound1 = _Target_index + 1;
const _Diff _Bound2 = _Target_index + 2;

if (static_cast<uint64_t>(_Bound2) <= _Batched_rng_from_urng<_Diff, _Urng>::_Bound_for_batch_2) {
auto _Target_next = _Target;
++_Target_next;

if (_Target_next != _Last) {
_Diff _Offsets[2];
_BatchedRng._Batch_2(_Offsets, _Bound1, _Bound2);

_STL_ASSERT(0 <= _Offsets[0] && _Offsets[0] <= _Target_index, "random value out of range");
_STL_ASSERT(0 <= _Offsets[1] && _Offsets[1] <= _Target_index + 1, "random value out of range");

if (_Offsets[0] != _Target_index) {
_RANGES iter_swap(_Target, _First + _Offsets[0]);
}

++_Target;
++_Target_index;

if (_Offsets[1] != _Target_index) {
_RANGES iter_swap(_Target, _First + _Offsets[1]);
}

++_Target_index; // the increment at the top of the loop advances _Target
continue;
}
}

const _Diff _Off = _BatchedRng._Single_bounded(_Bound1);
_STL_ASSERT(0 <= _Off && _Off <= _Target_index, "random value out of range");
if (_Off != _Target_index) {
_RANGES iter_swap(_Target, _First + _Off);
}

++_Target_index; // the increment at the top of the loop advances _Target
}
return _Target;
}
};

_EXPORT_STD inline constexpr _Shuffle_fn shuffle;
70 changes: 70 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_shuffle/test.cpp
@@ -4,17 +4,21 @@
#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <random>
#include <ranges>
#include <utility>
#include <vector>

#include <range_algorithm_support.hpp>
using namespace std;

const unsigned int seed = random_device{}();
mt19937 gen{seed};
mt19937_64 gen64{seed}; // 64-bit generator for batched random path

// Validate dangling story
static_assert(same_as<decltype(ranges::shuffle(borrowed<false>{}, gen)), ranges::dangling>);
@@ -72,8 +76,74 @@ void test_urbg() { // COMPILE-ONLY
ranges::shuffle(arr, RandGen{});
}

// Test that shuffle produces a valid permutation for various sizes.
// This exercises both the batched path (for full-range 64-bit URNGs) and the fallback path.
void test_shuffle_permutation() {
const vector<int> original = [] {
vector<int> ret(100);
iota(ret.begin(), ret.end(), 0);
return ret;
}();

// Test with 64-bit generator (batched random path)
{
vector<int> v = original;
shuffle(v.begin(), v.end(), gen64);
sort(v.begin(), v.end());
assert(v == original); // Verify it's still a permutation
}

// Test with ranges::shuffle and 64-bit generator (batched random path)
{
vector<int> v = original;
ranges::shuffle(v, gen64);
sort(v.begin(), v.end());
assert(v == original); // Verify it's still a permutation
}

// Test with 32-bit generator (non-batched path)
{
vector<int> v = original;
shuffle(v.begin(), v.end(), gen);
sort(v.begin(), v.end());
assert(v == original); // Verify it's still a permutation
}

// Test with ranges::shuffle and 32-bit generator (non-batched path)
{
vector<int> v = original;
ranges::shuffle(v, gen);
sort(v.begin(), v.end());
assert(v == original); // Verify it's still a permutation
}
}

[[nodiscard]] bool shuffle_is_a_permutation(const size_t n) {
vector<int> v(n);
iota(v.begin(), v.end(), 0);
const vector<int> original = v;
shuffle(v.begin(), v.end(), gen64);
sort(v.begin(), v.end());
return v == original; // have the caller assert() for clearer diagnostics
}

// Test edge cases for shuffle
void test_shuffle_edge_cases() {
// Test both even and odd sizes to hit the batching boundary,
// and large sizes to give the batched path real work.
assert(shuffle_is_a_permutation(0));
assert(shuffle_is_a_permutation(1));
assert(shuffle_is_a_permutation(2));
assert(shuffle_is_a_permutation(3));
assert(shuffle_is_a_permutation(4));
assert(shuffle_is_a_permutation(1729));
assert(shuffle_is_a_permutation(10000));
}
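
// A possible additional check (a sketch, not wired into main() below; a call would need to be added):
// a user-defined full-range 64-bit URNG should also take the batched path and still permute.
struct full_range_urng64 { // hypothetical splitmix64-style generator
using result_type = uint64_t;
uint64_t state = 0x9E3779B97F4A7C15ULL;
static constexpr uint64_t(min)() {
return 0;
}
static constexpr uint64_t(max)() {
return UINT64_MAX;
}
uint64_t operator()() {
uint64_t z = (state += 0x9E3779B97F4A7C15ULL);
z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
return z ^ (z >> 31);
}
};

[[nodiscard]] bool custom_urng_shuffle_is_a_permutation() {
vector<int> v(512);
iota(v.begin(), v.end(), 0);
const vector<int> original = v;
full_range_urng64 urng;
ranges::shuffle(v, urng);
sort(v.begin(), v.end());
return v == original;
}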

int main() {
printf("Using seed: %u\n", seed);

test_random<instantiator, int>();
test_shuffle_permutation();
test_shuffle_edge_cases();
}