From cc5b580bc90856d3e3f0b3232045eaf6afde0310 Mon Sep 17 00:00:00 2001
From: Hari Limaye
Date: Tue, 17 Feb 2026 10:55:02 +0000
Subject: [PATCH] Add Neon implementation of `replace`

---
 stl/inc/xutility              |   2 +-
 stl/src/vector_algorithms.cpp | 133 +++++++++++++++++++++++++++++++++-
 2 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/stl/inc/xutility b/stl/inc/xutility
index 2d5bd7203c..cab17bcb55 100644
--- a/stl/inc/xutility
+++ b/stl/inc/xutility
@@ -88,7 +88,7 @@ _STL_DISABLE_CLANG_WARNINGS
 #define _VECTORIZED_REMOVE _VECTORIZED_FOR_X64_X86
 #define _VECTORIZED_REMOVE_COPY _VECTORIZED_FOR_X64_X86
 #define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86
-#define _VECTORIZED_REPLACE_COPY _VECTORIZED_FOR_X64_X86
+#define _VECTORIZED_REPLACE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
 #define _VECTORIZED_REVERSE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
 #define _VECTORIZED_REVERSE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
 #define _VECTORIZED_ROTATE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp
index e4b29de2a5..0b8900b563 100644
--- a/stl/src/vector_algorithms.cpp
+++ b/stl/src/vector_algorithms.cpp
@@ -4029,6 +4029,14 @@ namespace {
             return vld1_u8(static_cast<const uint8_t*>(_Ptr));
         }
 
+        static void _Store_q(void* const _Ptr, const uint8x16_t _Val) noexcept {
+            vst1q_u8(static_cast<uint8_t*>(_Ptr), _Val);
+        }
+
+        static void _Store(void* const _Ptr, const uint8x8_t _Val) noexcept {
+            vst1_u8(static_cast<uint8_t*>(_Ptr), _Val);
+        }
+
         static uint8x16_t _Set_neon_q(const uint8_t _Val) noexcept {
             return vdupq_n_u8(_Val);
         }
@@ -4065,6 +4073,14 @@ namespace {
             const auto _Comb = vreinterpretq_u64_u8(vpminq_u8(_Cmp, _Cmp));
             return vgetq_lane_u64(_Comb, 0) ^ 0xFFFF'FFFF'FFFF'FFFF;
         }
+
+        static uint8x16_t _Blend_q(const uint8x16_t _Px1, const uint8x16_t _Px2, const uint8x16_t _Msk) noexcept {
+            return vbslq_u8(_Msk, _Px2, _Px1);
+        }
+
+        static uint8x8_t _Blend(const uint8x8_t _Px1, const uint8x8_t _Px2, const uint8x8_t _Msk) noexcept {
+            return vbsl_u8(_Msk, _Px2, _Px1);
+        }
     };
 
     struct _Find_traits_2 {
@@ -4076,6 +4092,14 @@ namespace {
             return vld1_u16(static_cast<const uint16_t*>(_Ptr));
         }
 
+        static void _Store_q(void* const _Ptr, const uint16x8_t _Val) noexcept {
+            vst1q_u16(static_cast<uint16_t*>(_Ptr), _Val);
+        }
+
+        static void _Store(void* const _Ptr, const uint16x4_t _Val) noexcept {
+            vst1_u16(static_cast<uint16_t*>(_Ptr), _Val);
+        }
+
         static uint16x8_t _Set_neon_q(const uint16_t _Val) noexcept {
             return vdupq_n_u16(_Val);
         }
@@ -4112,6 +4136,14 @@ namespace {
             const auto _Comb = vreinterpretq_u64_u16(vpminq_u16(_Cmp, _Cmp));
             return vgetq_lane_u64(_Comb, 0) ^ 0xFFFF'FFFF'FFFF'FFFF;
         }
+
+        static uint16x8_t _Blend_q(const uint16x8_t _Px1, const uint16x8_t _Px2, const uint16x8_t _Msk) noexcept {
+            return vbslq_u16(_Msk, _Px2, _Px1);
+        }
+
+        static uint16x4_t _Blend(const uint16x4_t _Px1, const uint16x4_t _Px2, const uint16x4_t _Msk) noexcept {
+            return vbsl_u16(_Msk, _Px2, _Px1);
+        }
     };
 
     struct _Find_traits_4 {
@@ -4123,6 +4155,14 @@ namespace {
             return vld1_u32(static_cast<const uint32_t*>(_Ptr));
         }
 
+        static void _Store_q(void* const _Ptr, const uint32x4_t _Val) noexcept {
+            vst1q_u32(static_cast<uint32_t*>(_Ptr), _Val);
+        }
+
+        static void _Store(void* const _Ptr, const uint32x2_t _Val) noexcept {
+            vst1_u32(static_cast<uint32_t*>(_Ptr), _Val);
+        }
+
         static uint32x4_t _Set_neon_q(const uint32_t _Val) noexcept {
             return vdupq_n_u32(_Val);
         }
@@ -4159,6 +4199,14 @@ namespace {
             const auto _Comb = vreinterpretq_u64_u32(vpminq_u32(_Cmp, _Cmp));
             return vgetq_lane_u64(_Comb, 0) ^ 0xFFFF'FFFF'FFFF'FFFF;
         }
+
+        static uint32x4_t _Blend_q(const uint32x4_t _Px1, const uint32x4_t _Px2, const uint32x4_t _Msk) noexcept {
+            return vbslq_u32(_Msk, _Px2, _Px1);
+        }
+
+        static uint32x2_t _Blend(const uint32x2_t _Px1, const uint32x2_t _Px2, const uint32x2_t _Msk) noexcept {
+            return vbsl_u32(_Msk, _Px2, _Px1);
+        }
     };
 
     struct _Find_traits_8 {
@@ -4166,6 +4214,10 @@ namespace {
             return vld1q_u64(static_cast<const uint64_t*>(_Ptr));
         }
 
+        static void _Store_q(void* const _Ptr, const uint64x2_t _Val) noexcept {
+            vst1q_u64(static_cast<uint64_t*>(_Ptr), _Val);
+        }
+
         static uint64x2_t _Set_neon_q(const uint64_t _Val) noexcept {
             return vdupq_n_u64(_Val);
         }
@@ -4188,6 +4240,10 @@ namespace {
         static uint64_t _Match_mask_ne(const uint64x2_t _Cmp_lo, const uint64x2_t _Cmp_hi) noexcept {
             return _Mask_q(vandq_u64(_Cmp_lo, _Cmp_hi)) ^ 0xFFFF'FFFF'FFFF'FFFF;
         }
+
+        static uint64x2_t _Blend_q(const uint64x2_t _Px1, const uint64x2_t _Px2, const uint64x2_t _Msk) noexcept {
+            return vbslq_u64(_Msk, _Px2, _Px1);
+        }
     };
 
     unsigned long _Get_first_h_pos_q(const uint64_t _Mask) noexcept {
@@ -8138,13 +8194,81 @@ __declspec(noalias) size_t __stdcall __std_mismatch_8(
 
 } // extern "C"
 
-#ifndef _M_ARM64
 namespace {
     namespace _Replacing {
+#if defined(_M_ARM64) || defined(_M_ARM64EC)
+
+        template <class _Traits, class _Ty>
+        __declspec(noalias) void __stdcall _Replace_copy_impl(
+            const void* _First, const void* const _Last, void* _Dest, const _Ty _Old_val, const _Ty _New_val) noexcept {
+            const size_t _Size_bytes = _Byte_length(_First, _Last);
+
+            if (const size_t _Size = _Size_bytes & ~size_t{0x1F}; _Size != 0) {
+                const auto _Comparand   = _Traits::_Set_neon_q(_Old_val);
+                const auto _Replacement = _Traits::_Set_neon_q(_New_val);
+
+                const void* _Stop_at = _First;
+                _Advance_bytes(_Stop_at, _Size);
+
+                do {
+                    const auto _Data_lo = _Traits::_Load_q(static_cast<const unsigned char*>(_First) + 0);
+                    const auto _Data_hi = _Traits::_Load_q(static_cast<const unsigned char*>(_First) + 16);
+
+                    const auto _Mask_lo = _Traits::_Cmp_neon_q(_Data_lo, _Comparand);
+                    const auto _Mask_hi = _Traits::_Cmp_neon_q(_Data_hi, _Comparand);
+
+                    const auto _Val_lo = _Traits::_Blend_q(_Data_lo, _Replacement, _Mask_lo);
+                    const auto _Val_hi = _Traits::_Blend_q(_Data_hi, _Replacement, _Mask_hi);
+
+                    _Traits::_Store_q(static_cast<unsigned char*>(_Dest) + 0, _Val_lo);
+                    _Traits::_Store_q(static_cast<unsigned char*>(_Dest) + 16, _Val_hi);
+
+                    _Advance_bytes(_First, 32);
+                    _Advance_bytes(_Dest, 32);
+                } while (_First != _Stop_at);
+            }
+
+            if ((_Size_bytes & size_t{0x10}) != 0) { // use original _Size_bytes; we've read only 32-byte chunks
+                const auto _Comparand   = _Traits::_Set_neon_q(_Old_val);
+                const auto _Replacement = _Traits::_Set_neon_q(_New_val);
+
+                const auto _Data = _Traits::_Load_q(_First);
+                const auto _Mask = _Traits::_Cmp_neon_q(_Data, _Comparand);
+                const auto _Val  = _Traits::_Blend_q(_Data, _Replacement, _Mask);
+
+                _Traits::_Store_q(_Dest, _Val);
+
+                _Advance_bytes(_First, 16);
+                _Advance_bytes(_Dest, 16);
+            }
+
+            if constexpr (sizeof(_Ty) < 8) {
+                if ((_Size_bytes & size_t{0x08}) != 0) { // use original _Size_bytes; we've read only 16/32-byte chunks
+                    const auto _Comparand   = _Traits::_Set_neon(_Old_val);
+                    const auto _Replacement = _Traits::_Set_neon(_New_val);
+
+                    const auto _Data = _Traits::_Load(_First);
+                    const auto _Mask = _Traits::_Cmp_neon(_Data, _Comparand);
+                    const auto _Val  = _Traits::_Blend(_Data, _Replacement, _Mask);
+
+                    _Traits::_Store(_Dest, _Val);
+
+                    _Advance_bytes(_First, 8);
+                    _Advance_bytes(_Dest, 8);
+                }
+            }
+
+            auto _Ptr_dest = static_cast<_Ty*>(_Dest);
+// Avoid auto-vectorization of the scalar tail, as this is not beneficial for performance.
+#pragma loop(no_vector)
+            for (auto _Ptr_src = static_cast<const _Ty*>(_First); _Ptr_src != _Last; ++_Ptr_src) {
+                const _Ty _Val = *_Ptr_src;
+                *_Ptr_dest     = _Val == _Old_val ? _New_val : _Val;
+                ++_Ptr_dest;
+            }
+        }
+#else // ^^^ defined(_M_ARM64) || defined(_M_ARM64EC) / !defined(_M_ARM64) && !defined(_M_ARM64EC) vvv
         template <class _Traits, class _Ty>
         __declspec(noalias) void __stdcall _Replace_copy_impl(
             const void* _First, const void* const _Last, void* _Dest, const _Ty _Old_val, const _Ty _New_val) noexcept {
-#ifndef _M_ARM64EC
             const size_t _Size_bytes = _Byte_length(_First, _Last);
 
             if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
@@ -8198,7 +8322,6 @@ namespace {
                 _Advance_bytes(_Dest, 16);
             } while (_First != _Stop_at);
         }
-#endif // ^^^ !defined(_M_ARM64EC) ^^^
         auto _Ptr_dest = static_cast<_Ty*>(_Dest);
         for (auto _Ptr_src = static_cast<const _Ty*>(_First); _Ptr_src != _Last; ++_Ptr_src) {
             const _Ty _Val = *_Ptr_src;
@@ -8206,11 +8329,13 @@ namespace {
             ++_Ptr_dest;
         }
     }
+#endif // ^^^ !defined(_M_ARM64) && !defined(_M_ARM64EC) ^^^
     } // namespace _Replacing
 } // unnamed namespace
 
 extern "C" {
 
+#ifndef _M_ARM64
 __declspec(noalias) void __stdcall __std_replace_4(
     void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
 #ifndef _M_ARM64EC
@@ -8291,6 +8416,7 @@ __declspec(noalias) void __stdcall __std_replace_8(
         }
     }
 }
+#endif // ^^^ !defined(_M_ARM64) ^^^
 
 __declspec(noalias) void __stdcall __std_replace_copy_1(const void* const _First, const void* const _Last,
     void* const _Dest, const uint8_t _Old_val, const uint8_t _New_val) noexcept {
@@ -8314,6 +8440,7 @@ __declspec(noalias) void __stdcall __std_replace_copy_8(const void* const _First
 
 } // extern "C"
 
+#ifndef _M_ARM64
 namespace {
     namespace _Removing {
         // 'remove' and 'unique': form bit mask based on matches, then do _mm_shuffle_epi8/_mm256_permutevar8x32_epi32