From 67be7536d0996d3ccd5b68ad3e79e5f0ef5272e5 Mon Sep 17 00:00:00 2001 From: immanuelazn Date: Mon, 25 Nov 2024 12:08:33 -0800 Subject: [PATCH] [cpp] update hwy to v1.2.0 --- r/src/vendor/highway/README.md | 2 +- r/src/vendor/highway/hwy/abort.cc | 117 + r/src/vendor/highway/hwy/abort.h | 44 + r/src/vendor/highway/hwy/aligned_allocator.cc | 23 +- r/src/vendor/highway/hwy/aligned_allocator.h | 236 +- r/src/vendor/highway/hwy/base.h | 2598 +++-- r/src/vendor/highway/hwy/bit_set.h | 158 + r/src/vendor/highway/hwy/cache_control.h | 30 +- .../highway/hwy/contrib/algo/copy-inl.h | 28 +- .../highway/hwy/contrib/algo/copy_test.cc | 18 +- .../highway/hwy/contrib/algo/find-inl.h | 19 +- .../highway/hwy/contrib/algo/find_test.cc | 34 +- .../highway/hwy/contrib/algo/transform-inl.h | 188 +- .../hwy/contrib/algo/transform_test.cc | 164 +- .../highway/hwy/contrib/math/math-inl.h | 279 +- .../highway/hwy/contrib/math/math_test.cc | 358 +- .../vendor/highway/hwy/detect_compiler_arch.h | 115 +- r/src/vendor/highway/hwy/detect_targets.h | 247 +- r/src/vendor/highway/hwy/foreach_target.h | 37 +- r/src/vendor/highway/hwy/highway.h | 272 +- r/src/vendor/highway/hwy/nanobenchmark.cc | 24 +- r/src/vendor/highway/hwy/nanobenchmark.h | 20 +- r/src/vendor/highway/hwy/ops/arm_neon-inl.h | 3276 +++++-- r/src/vendor/highway/hwy/ops/arm_sve-inl.h | 1923 +++- r/src/vendor/highway/hwy/ops/emu128-inl.h | 1405 +-- .../vendor/highway/hwy/ops/generic_ops-inl.h | 6937 +++++++++----- r/src/vendor/highway/hwy/ops/inside-inl.h | 691 ++ r/src/vendor/highway/hwy/ops/ppc_vsx-inl.h | 2898 ++++-- r/src/vendor/highway/hwy/ops/rvv-inl.h | 2308 +++-- r/src/vendor/highway/hwy/ops/scalar-inl.h | 646 +- r/src/vendor/highway/hwy/ops/set_macros-inl.h | 201 +- r/src/vendor/highway/hwy/ops/shared-inl.h | 267 +- r/src/vendor/highway/hwy/ops/tuple-inl.h | 125 - r/src/vendor/highway/hwy/ops/wasm_128-inl.h | 575 +- r/src/vendor/highway/hwy/ops/wasm_256-inl.h | 230 +- r/src/vendor/highway/hwy/ops/x86_128-inl.h | 8389 +++++++++++------ r/src/vendor/highway/hwy/ops/x86_256-inl.h | 2456 +++-- r/src/vendor/highway/hwy/ops/x86_512-inl.h | 2327 ++++- r/src/vendor/highway/hwy/per_target.cc | 19 + r/src/vendor/highway/hwy/per_target.h | 7 +- r/src/vendor/highway/hwy/print.cc | 38 +- r/src/vendor/highway/hwy/profiler.h | 682 ++ r/src/vendor/highway/hwy/robust_statistics.h | 4 +- r/src/vendor/highway/hwy/stats.cc | 120 + r/src/vendor/highway/hwy/stats.h | 194 + r/src/vendor/highway/hwy/targets.cc | 367 +- r/src/vendor/highway/hwy/targets.h | 72 +- r/src/vendor/highway/hwy/timer-inl.h | 8 +- r/src/vendor/highway/hwy/timer.cc | 36 +- r/src/vendor/highway/hwy/timer.h | 17 +- .../highway/manual-build/build_highway.sh | 13 +- 51 files changed, 30107 insertions(+), 11135 deletions(-) create mode 100644 r/src/vendor/highway/hwy/abort.cc create mode 100644 r/src/vendor/highway/hwy/abort.h create mode 100644 r/src/vendor/highway/hwy/bit_set.h create mode 100644 r/src/vendor/highway/hwy/ops/inside-inl.h delete mode 100644 r/src/vendor/highway/hwy/ops/tuple-inl.h create mode 100644 r/src/vendor/highway/hwy/profiler.h create mode 100644 r/src/vendor/highway/hwy/stats.cc create mode 100644 r/src/vendor/highway/hwy/stats.h diff --git a/r/src/vendor/highway/README.md b/r/src/vendor/highway/README.md index 948c713e..07899c1a 100644 --- a/r/src/vendor/highway/README.md +++ b/r/src/vendor/highway/README.md @@ -6,7 +6,7 @@ To prep the source tree starting from the highway root dir, run: ```bash mkdir -p lib-copy cp -r hwy lib-copy -rm -r lib-copy/hwy/{examples,tests,*_test.cc} 
lib-copy/hwy/contrib/{bit_pack,dot,image,sort,unroller} +rm -r lib-copy/hwy/{examples,tests,*_test.cc} lib-copy/hwy/contrib/{bit_pack,dot,image,sort,unroller,random,thread_pool,matvec} ``` The files in `manual-build` are custom scripts for manually building the library without a cmake dependency diff --git a/r/src/vendor/highway/hwy/abort.cc b/r/src/vendor/highway/hwy/abort.cc new file mode 100644 index 00000000..a67819bb --- /dev/null +++ b/r/src/vendor/highway/hwy/abort.cc @@ -0,0 +1,117 @@ +// Copyright 2019 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause + +#include "hwy/abort.h" + +#include +#include +#include + +#include +#include + +#include "hwy/base.h" + +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif + +namespace hwy { + +namespace { + +std::atomic& AtomicWarnFunc() { + static std::atomic func; + return func; +} + +std::atomic& AtomicAbortFunc() { + static std::atomic func; + return func; +} + +std::string GetBaseName(std::string const& file_name) { + auto last_slash = file_name.find_last_of("/\\"); + return file_name.substr(last_slash + 1); +} + +} // namespace + +// Returning a reference is unfortunately incompatible with `std::atomic`, which +// is required to safely implement `SetWarnFunc`. As a workaround, we store a +// copy here, update it when called, and return a reference to the copy. This +// has the added benefit of protecting the actual pointer from modification. +HWY_DLLEXPORT WarnFunc& GetWarnFunc() { + static WarnFunc func; + func = AtomicWarnFunc().load(); + return func; +} + +HWY_DLLEXPORT AbortFunc& GetAbortFunc() { + static AbortFunc func; + func = AtomicAbortFunc().load(); + return func; +} + +HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func) { + return AtomicWarnFunc().exchange(func); +} + +HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func) { + return AtomicAbortFunc().exchange(func); +} + +HWY_DLLEXPORT void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + WarnFunc handler = AtomicWarnFunc().load(); + if (handler != nullptr) { + handler(file, line, buf); + } else { + fprintf(stderr, "Warn at %s:%d: %s\n", GetBaseName(file).data(), line, buf); + } +} + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + AbortFunc handler = AtomicAbortFunc().load(); + if (handler != nullptr) { + handler(file, line, buf); + } else { + fprintf(stderr, "Abort at %s:%d: %s\n", GetBaseName(file).data(), line, + buf); + } + +// If compiled with any sanitizer, they can also print a stack trace. +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + __sanitizer_print_stack_trace(); +#endif // HWY_IS_* + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RISCV + exit(1); // trap/abort just freeze Spike. +#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC && !HWY_ARCH_ARM + // Facilitates breaking into a debugger, but don't use this in non-debug + // builds because it looks like "illegal instruction", which is misleading. + // Also does not work on Arm. + __builtin_trap(); +#else + abort(); // Compile error without this due to HWY_NORETURN. 
+#endif +} + +} // namespace hwy diff --git a/r/src/vendor/highway/hwy/abort.h b/r/src/vendor/highway/hwy/abort.h new file mode 100644 index 00000000..afa68fb4 --- /dev/null +++ b/r/src/vendor/highway/hwy/abort.h @@ -0,0 +1,44 @@ +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef HIGHWAY_HWY_ABORT_H_ +#define HIGHWAY_HWY_ABORT_H_ + +#include "hwy/highway_export.h" + +namespace hwy { + +// Interfaces for custom Warn/Abort handlers. +typedef void (*WarnFunc)(const char* file, int line, const char* message); + +typedef void (*AbortFunc)(const char* file, int line, const char* message); + +// Returns current Warn() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT WarnFunc& GetWarnFunc(); + +// Returns current Abort() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr and abort. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT AbortFunc& GetAbortFunc(); + +// Sets a new Warn() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. Thread-safe. +HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func); + +// Sets a new Abort() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. If all handlers return, then Highway will terminate the app. +// Thread-safe. +HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func); + +// Abort()/Warn() and HWY_ABORT/HWY_WARN are declared in base.h. + +} // namespace hwy + +#endif // HIGHWAY_HWY_ABORT_H_ diff --git a/r/src/vendor/highway/hwy/aligned_allocator.cc b/r/src/vendor/highway/hwy/aligned_allocator.cc index e240a49e..e857b228 100644 --- a/r/src/vendor/highway/hwy/aligned_allocator.cc +++ b/r/src/vendor/highway/hwy/aligned_allocator.cc @@ -27,7 +27,8 @@ namespace hwy { namespace { -#if HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && __riscv_v_intrinsic >= 11000 +#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 // Not actually an upper bound on the size, but this value prevents crossing a // 4K boundary (relevant on Andes). constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, 4096); @@ -36,9 +37,11 @@ constexpr size_t kAlignment = HWY_ALIGNMENT; #endif #if HWY_ARCH_X86 -// On x86, aliasing can only occur at multiples of 2K, but that's too wasteful -// if this is used for single-vector allocations. 256 is more reasonable. -constexpr size_t kAlias = kAlignment * 4; +// On x86, aliasing can only occur at multiples of 2K. To reduce the chance of +// allocations being equal mod 2K, we round up to kAlias and add a cyclic +// offset which is a multiple of kAlignment. Rounding up to only 1K decreases +// the number of alias-free allocations, but also wastes less memory. +constexpr size_t kAlias = HWY_MAX(kAlignment, 1024); #else constexpr size_t kAlias = kAlignment; #endif @@ -52,9 +55,10 @@ struct AllocationHeader { // Returns a 'random' (cyclical) offset for AllocateAlignedBytes. 
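Aside (illustrative, not part of the patch): the new `hwy/abort.h` above lets callers install process-wide handlers via `SetWarnFunc`/`SetAbortFunc`. A minimal sketch, with a handler name of our own choosing:

```cpp
#include <cstdio>
#include <cstdlib>

#include "hwy/abort.h"
#include "hwy/base.h"  // HWY_WARN, HWY_ABORT

// Hypothetical handler: log and exit ourselves. If an abort handler returns,
// Highway terminates the program on its own.
static void MyAbortHandler(const char* file, int line, const char* message) {
  std::fprintf(stderr, "fatal at %s:%d: %s\n", file, line, message);
  std::exit(2);
}

int main() {
  // Returns the previously registered handler, or nullptr if none was set.
  hwy::AbortFunc prev = hwy::SetAbortFunc(MyAbortHandler);
  (void)prev;
  HWY_WARN("about to fail (attempt %d)", 1);  // stderr unless a WarnFunc is set
  HWY_ABORT("unsupported value: %d", 42);     // routed through MyAbortHandler
}
```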
size_t NextAlignedOffset() { - static std::atomic next{0}; - constexpr uint32_t kGroups = kAlias / kAlignment; - const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + static std::atomic next{0}; + static_assert(kAlias % kAlignment == 0, "kAlias must be a multiple"); + constexpr size_t kGroups = kAlias / kAlignment; + const size_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; const size_t offset = kAlignment * group; HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias); return offset; @@ -79,8 +83,7 @@ HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size, // To avoid wasting space, the header resides at the end of `unused`, // which therefore cannot be empty (offset == 0). if (offset == 0) { - offset = kAlignment; // = RoundUpTo(sizeof(AllocationHeader), kAlignment) - static_assert(sizeof(AllocationHeader) <= kAlignment, "Else: round up"); + offset = RoundUpTo(sizeof(AllocationHeader), kAlignment); } const size_t allocated_size = kAlias + offset + payload_size; @@ -99,10 +102,12 @@ HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size, aligned &= ~(kAlias - 1); const uintptr_t payload = aligned + offset; // still aligned + HWY_DASSERT(payload % kAlignment == 0); // Stash `allocated` and payload_size inside header for FreeAlignedBytes(). // The allocated_size can be reconstructed from the payload_size. AllocationHeader* header = reinterpret_cast(payload) - 1; + HWY_DASSERT(reinterpret_cast(header) >= aligned); header->allocated = allocated; header->payload_size = payload_size; diff --git a/r/src/vendor/highway/hwy/aligned_allocator.h b/r/src/vendor/highway/hwy/aligned_allocator.h index d0671a57..e738c8be 100644 --- a/r/src/vendor/highway/hwy/aligned_allocator.h +++ b/r/src/vendor/highway/hwy/aligned_allocator.h @@ -18,17 +18,32 @@ // Memory allocator with support for alignment and offsets. +#include +#include +#include +#include +#include +#include #include +#include #include +#include #include "hwy/base.h" +#include "hwy/per_target.h" namespace hwy { // Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which -// requires a literal. This matches typical L1 cache line sizes, which prevents -// false sharing. -#define HWY_ALIGNMENT 64 +// requires a literal. To prevent false sharing, this should be at least the +// L1 cache line size, usually 64 bytes. However, Intel's L2 prefetchers may +// access pairs of lines, and M1 L2 and POWER8 lines are also 128 bytes. +#define HWY_ALIGNMENT 128 + +template +HWY_API constexpr bool IsAligned(T* ptr, size_t align = HWY_ALIGNMENT) { + return reinterpret_cast(ptr) % align == 0; +} // Pointers to functions equivalent to malloc/free with an opaque void* passed // to them. @@ -40,7 +55,8 @@ using FreePtr = void (*)(void* opaque, void* memory); // the vector size. Calls `alloc` with the passed `opaque` pointer to obtain // memory or malloc() if it is null. HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, - AllocPtr alloc_ptr, void* opaque_ptr); + AllocPtr alloc_ptr = nullptr, + void* opaque_ptr = nullptr); // Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it // must have been returned from a previous call to `AllocateAlignedBytes`. @@ -110,12 +126,51 @@ AlignedUniquePtr MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free, // functions. template AlignedUniquePtr MakeUniqueAligned(Args&&... 
args) { - T* ptr = static_cast(AllocateAlignedBytes( - sizeof(T), /*alloc_ptr=*/nullptr, /*opaque_ptr=*/nullptr)); + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T))); return AlignedUniquePtr(new (ptr) T(std::forward(args)...), AlignedDeleter()); } +template +struct AlignedAllocator { + using value_type = T; + + AlignedAllocator() = default; + + template + explicit AlignedAllocator(const AlignedAllocator&) noexcept {} + + template + value_type* allocate(V n) { + static_assert(std::is_integral::value, + "AlignedAllocator only supports integer types"); + static_assert(sizeof(V) <= sizeof(std::size_t), + "V n must be smaller or equal size_t to avoid overflow"); + return static_cast( + AllocateAlignedBytes(static_cast(n) * sizeof(value_type))); + } + + template + void deallocate(value_type* p, HWY_MAYBE_UNUSED V n) { + return FreeAlignedBytes(p, nullptr, nullptr); + } +}; + +template +constexpr bool operator==(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return true; +} + +template +constexpr bool operator!=(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return false; +} + +template +using AlignedVector = std::vector>; + // Helpers for array allocators (avoids overflow) namespace detail { @@ -126,14 +181,14 @@ static inline constexpr size_t ShiftCount(size_t n) { template T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) { - constexpr size_t size = sizeof(T); + constexpr size_t kSize = sizeof(T); - constexpr bool is_pow2 = (size & (size - 1)) == 0; - constexpr size_t bits = ShiftCount(size); - static_assert(!is_pow2 || (1ull << bits) == size, "ShiftCount is incorrect"); + constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; + constexpr size_t kBits = ShiftCount(kSize); + static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); - const size_t bytes = is_pow2 ? items << bits : items * size; - const size_t check = is_pow2 ? bytes >> bits : bytes / size; + const size_t bytes = kIsPow2 ? items << kBits : items * kSize; + const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; if (check != items) { return nullptr; // overflowed } @@ -207,5 +262,162 @@ AlignedFreeUniquePtr AllocateAligned(const size_t items) { return AllocateAligned(items, nullptr, nullptr, nullptr); } +// A simple span containing data and size of data. +template +class Span { + public: + Span() = default; + Span(T* data, size_t size) : size_(size), data_(data) {} + template + Span(U u) : Span(u.data(), u.size()) {} + Span(std::initializer_list v) : Span(v.begin(), v.size()) {} + + // Copies the contents of the initializer list to the span. + Span& operator=(std::initializer_list v) { + HWY_DASSERT(size_ == v.size()); + CopyBytes(v.begin(), data_, sizeof(T) * std::min(size_, v.size())); + return *this; + } + + // Returns the size of the contained data. + size_t size() const { return size_; } + + // Returns a pointer to the contained data. + T* data() { return data_; } + T* data() const { return data_; } + + // Returns the element at index. + T& operator[](size_t index) const { return data_[index]; } + + // Returns an iterator pointing to the first element of this span. + T* begin() { return data_; } + + // Returns a const iterator pointing to the first element of this span. + constexpr const T* cbegin() const { return data_; } + + // Returns an iterator pointing just beyond the last element at the + // end of this span. 
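Aside (illustrative, not part of the patch): the allocator changes above — default arguments on `AllocateAlignedBytes`, plus the new `IsAligned`, `AlignedAllocator` and `AlignedVector` — can be exercised as in this sketch; the `Point` type is made up for the example:

```cpp
#include "hwy/aligned_allocator.h"

struct Point {
  float x, y;
};

int main() {
  // Single object constructed in aligned storage, freed by AlignedDeleter.
  hwy::AlignedUniquePtr<Point> p = hwy::MakeUniqueAligned<Point>(Point{1.0f, 2.0f});

  // Array of lanes suitable for aligned SIMD loads/stores.
  hwy::AlignedFreeUniquePtr<float[]> buf = hwy::AllocateAligned<float>(256);
  HWY_ASSERT(hwy::IsAligned(buf.get()));  // aligned to HWY_ALIGNMENT (now 128)

  // std::vector whose storage comes from AllocateAlignedBytes.
  hwy::AlignedVector<float> v(100, 0.0f);

  buf[0] = p->x;
  v[0] = p->y;
  return 0;
}
```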
+ T* end() { return data_ + size_; } + + // Returns a const iterator pointing just beyond the last element at the + // end of this span. + constexpr const T* cend() const { return data_ + size_; } + + private: + size_t size_ = 0; + T* data_ = nullptr; +}; + +// A multi dimensional array containing an aligned buffer. +// +// To maintain alignment, the innermost dimension will be padded to ensure all +// innermost arrays are aligned. +template +class AlignedNDArray { + static_assert(std::is_trivial::value, + "AlignedNDArray can only contain trivial types"); + + public: + AlignedNDArray(AlignedNDArray&& other) = default; + AlignedNDArray& operator=(AlignedNDArray&& other) = default; + + // Constructs an array of the provided shape and fills it with zeros. + explicit AlignedNDArray(std::array shape) : shape_(shape) { + sizes_ = ComputeSizes(shape_); + memory_shape_ = shape_; + // Round the innermost dimension up to the number of bytes available for + // SIMD operations on this architecture to make sure that each innermost + // array is aligned from the first element. + memory_shape_[axes - 1] = RoundUpTo(memory_shape_[axes - 1], VectorBytes()); + memory_sizes_ = ComputeSizes(memory_shape_); + buffer_ = hwy::AllocateAligned(memory_size()); + hwy::ZeroBytes(buffer_.get(), memory_size() * sizeof(T)); + } + + // Returns a span containing the innermost array at the provided indices. + Span operator[](std::array indices) { + return Span(buffer_.get() + Offset(indices), sizes_[indices.size()]); + } + + // Returns a const span containing the innermost array at the provided + // indices. + Span operator[](std::array indices) const { + return Span(buffer_.get() + Offset(indices), + sizes_[indices.size()]); + } + + // Returns the shape of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + const std::array& shape() const { return shape_; } + + // Returns the shape of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. + const std::array& memory_shape() const { return memory_shape_; } + + // Returns the size of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + size_t size() const { return sizes_[0]; } + + // Returns the size of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. + size_t memory_size() const { return memory_sizes_[0]; } + + // Returns a pointer to the allocated buffer. + T* data() { return buffer_.get(); } + + // Returns a const pointer to the buffer. + const T* data() const { return buffer_.get(); } + + // Truncates the array by updating its shape. + // + // The new shape must be equal to or less than the old shape in all axes. + // + // Doesn't modify underlying memory. + void truncate(const std::array& new_shape) { +#if HWY_IS_DEBUG_BUILD + for (size_t axis_index = 0; axis_index < axes; ++axis_index) { + HWY_ASSERT(new_shape[axis_index] <= shape_[axis_index]); + } +#endif + shape_ = new_shape; + sizes_ = ComputeSizes(shape_); + } + + private: + std::array shape_; + std::array memory_shape_; + std::array sizes_; + std::array memory_sizes_; + hwy::AlignedFreeUniquePtr buffer_; + + // Computes offset in the buffer based on the provided indices. 
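Aside (illustrative, not part of the patch): a short sketch of the new `AlignedNDArray` and `Span` types defined above; the shape values are arbitrary:

```cpp
#include "hwy/aligned_allocator.h"

int main() {
  // 3 rows of 5 floats; the inner axis is padded to VectorBytes() so each row
  // starts at an aligned address, while the logical shape stays {3, 5}.
  hwy::AlignedNDArray<float, 2> a({3, 5});

  hwy::Span<float> row = a[{1}];  // view of row 1, length 5
  row[2] = 7.0f;

  // Shrink the logical shape without touching the underlying memory.
  a.truncate({2, 5});
  return a.shape()[0] == 2 ? 0 : 1;
}
```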
+ size_t Offset(std::array indices) const { + size_t offset = 0; + size_t shape_index = 0; + for (const size_t axis_index : indices) { + offset += memory_sizes_[shape_index + 1] * axis_index; + shape_index++; + } + return offset; + } + + // Computes the sizes of all sub arrays based on the sizes of each axis. + // + // Does this by multiplying the size of each axis with the previous one in + // reverse order, starting with the conceptual axis of size 1 containing the + // actual elements in the array. + static std::array ComputeSizes( + std::array shape) { + std::array sizes; + size_t axis = shape.size(); + sizes[axis] = 1; + while (axis > 0) { + --axis; + sizes[axis] = sizes[axis + 1] * shape[axis]; + } + return sizes; + } +}; + } // namespace hwy #endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ diff --git a/r/src/vendor/highway/hwy/base.h b/r/src/vendor/highway/hwy/base.h index 9d74f2b7..f2dc87c0 100644 --- a/r/src/vendor/highway/hwy/base.h +++ b/r/src/vendor/highway/hwy/base.h @@ -16,22 +16,31 @@ #ifndef HIGHWAY_HWY_BASE_H_ #define HIGHWAY_HWY_BASE_H_ -// For SIMD module implementations and their callers, target-independent. +// Target-independent definitions. // IWYU pragma: begin_exports #include #include -// Wrapping this into a HWY_HAS_INCLUDE causes clang-format to fail. -#if __cplusplus >= 202100L && defined(__has_include) -#if __has_include() -#include // std::float16_t -#endif +#if !defined(HWY_NO_LIBCXX) +#include #endif #include "hwy/detect_compiler_arch.h" #include "hwy/highway_export.h" +// API version (https://semver.org/); keep in sync with CMakeLists.txt. +#define HWY_MAJOR 1 +#define HWY_MINOR 2 +#define HWY_PATCH 0 + +// True if the Highway version >= major.minor.0. Added in 1.2.0. +#define HWY_VERSION_GE(major, minor) \ + (HWY_MAJOR > (major) || (HWY_MAJOR == (major) && HWY_MINOR >= (minor))) +// True if the Highway version < major.minor.0. Added in 1.2.0. +#define HWY_VERSION_LT(major, minor) \ + (HWY_MAJOR < (major) || (HWY_MAJOR == (major) && HWY_MINOR < (minor))) + // "IWYU pragma: keep" does not work for these includes, so hide from the IDE. 
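Aside (illustrative, not part of the patch): the `HWY_MAJOR`/`HWY_MINOR`/`HWY_PATCH` and `HWY_VERSION_GE` macros added above let downstream code gate on the vendored Highway version; the guard macro below is a hypothetical project-local name:

```cpp
#include "hwy/base.h"

#if HWY_VERSION_GE(1, 2)
// APIs added in 1.2.0 (e.g. HWY_WARN, hwy::AlignedNDArray) are available.
#define PROJECT_HAVE_HWY_1_2 1  // hypothetical guard, not part of Highway
#else
#define PROJECT_HAVE_HWY_1_2 0
#endif
```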
#if !HWY_IDE @@ -48,6 +57,26 @@ #endif // !HWY_IDE +#ifndef HWY_HAVE_COMPARE_HEADER // allow override +#define HWY_HAVE_COMPARE_HEADER 0 +#if defined(__has_include) // note: wrapper macro fails on Clang ~17 +#if __has_include() +#undef HWY_HAVE_COMPARE_HEADER +#define HWY_HAVE_COMPARE_HEADER 1 +#endif // __has_include +#endif // defined(__has_include) +#endif // HWY_HAVE_COMPARE_HEADER + +#ifndef HWY_HAVE_CXX20_THREE_WAY_COMPARE // allow override +#if !defined(HWY_NO_LIBCXX) && defined(__cpp_impl_three_way_comparison) && \ + __cpp_impl_three_way_comparison >= 201907L && HWY_HAVE_COMPARE_HEADER +#include +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 1 +#else +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 0 +#endif +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + // IWYU pragma: end_exports #if HWY_COMPILER_MSVC @@ -64,6 +93,7 @@ #include +#define HWY_FUNCTION __FUNCSIG__ // function name + template args #define HWY_RESTRICT __restrict #define HWY_INLINE __forceinline #define HWY_NOINLINE __declspec(noinline) @@ -84,6 +114,7 @@ #else +#define HWY_FUNCTION __PRETTY_FUNCTION__ // function name + template args #define HWY_RESTRICT __restrict__ // force inlining without optimization enabled creates very inefficient code // that can cause compiler timeout @@ -131,6 +162,12 @@ namespace hwy { #define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ #endif +// Returns a pointer whose type is `type` (T*), while allowing the compiler to +// assume that the untyped pointer `ptr` is aligned to a multiple of sizeof(T). +#define HWY_RCAST_ALIGNED(type, ptr) \ + reinterpret_cast( \ + HWY_ASSUME_ALIGNED((ptr), alignof(hwy::RemovePtr))) + // Clang and GCC require attributes on each function into which SIMD intrinsics // are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and // automatic annotation via pragmas. @@ -214,6 +251,12 @@ namespace hwy { // 4 instances of a given literal value, useful as input to LoadDup128. #define HWY_REP4(literal) literal, literal, literal, literal +HWY_DLLEXPORT void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...); + +#define HWY_WARN(format, ...) \ + ::hwy::Warn(__FILE__, __LINE__, format, ##__VA_ARGS__) + HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) Abort(const char* file, int line, const char* format, ...); @@ -221,31 +264,49 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__) // Always enabled. 
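Aside (illustrative, not part of the patch): a hedged sketch combining two additions above, `HWY_RCAST_ALIGNED` and the new `HWY_WARN`; the helper function is hypothetical:

```cpp
#include <cstddef>

#include "hwy/base.h"

// Hypothetical helper: sum a buffer that the caller guarantees is
// float-aligned (e.g. obtained from AllocateAligned<float>).
float SumAligned(const void* bytes, size_t num) {
  if (num == 0) {
    HWY_WARN("SumAligned called with %zu elements", num);
    return 0.0f;
  }
  // Lets the compiler assume `bytes` is aligned to alignof(float).
  const float* HWY_RESTRICT p = HWY_RCAST_ALIGNED(const float*, bytes);
  float sum = 0.0f;
  for (size_t i = 0; i < num; ++i) sum += p[i];
  return sum;
}
```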
-#define HWY_ASSERT(condition) \ - do { \ - if (!(condition)) { \ - HWY_ABORT("Assert %s", #condition); \ - } \ +#define HWY_ASSERT_M(condition, msg) \ + do { \ + if (!(condition)) { \ + HWY_ABORT("Assert %s: %s", #condition, msg); \ + } \ } while (0) +#define HWY_ASSERT(condition) HWY_ASSERT_M(condition, "") -#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) +#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) || \ + defined(__SANITIZE_MEMORY__) #define HWY_IS_MSAN 1 #else #define HWY_IS_MSAN 0 #endif -#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) +#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) || \ + defined(__SANITIZE_ADDRESS__) #define HWY_IS_ASAN 1 #else #define HWY_IS_ASAN 0 #endif -#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) +#if HWY_HAS_FEATURE(hwaddress_sanitizer) || defined(HWADDRESS_SANITIZER) || \ + defined(__SANITIZE_HWADDRESS__) +#define HWY_IS_HWASAN 1 +#else +#define HWY_IS_HWASAN 0 +#endif + +#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) || \ + defined(__SANITIZE_THREAD__) #define HWY_IS_TSAN 1 #else #define HWY_IS_TSAN 0 #endif +#if HWY_HAS_FEATURE(undefined_behavior_sanitizer) || \ + defined(UNDEFINED_BEHAVIOR_SANITIZER) +#define HWY_IS_UBSAN 1 +#else +#define HWY_IS_UBSAN 0 +#endif + // MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo. // You can disable MSAN by adding this attribute to the function that fails. #if HWY_IS_MSAN @@ -259,7 +320,8 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) // Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent // MSVC defines NDEBUG (if not, could instead check _DEBUG). #if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_ASAN || \ - HWY_IS_MSAN || HWY_IS_TSAN || defined(__clang_analyzer__) + HWY_IS_HWASAN || HWY_IS_MSAN || HWY_IS_TSAN || HWY_IS_UBSAN || \ + defined(__clang_analyzer__) #define HWY_IS_DEBUG_BUILD 1 #else #define HWY_IS_DEBUG_BUILD 0 @@ -267,8 +329,12 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) #endif // HWY_IS_DEBUG_BUILD #if HWY_IS_DEBUG_BUILD -#define HWY_DASSERT(condition) HWY_ASSERT(condition) +#define HWY_DASSERT_M(condition, msg) HWY_ASSERT_M(condition, msg) +#define HWY_DASSERT(condition) HWY_ASSERT_M(condition, "") #else +#define HWY_DASSERT_M(condition, msg) \ + do { \ + } while (0) #define HWY_DASSERT(condition) \ do { \ } while (0) @@ -282,14 +348,12 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) #pragma intrinsic(memset) #endif -// The source/destination must not overlap/alias. template -HWY_API void CopyBytes(const From* from, To* to) { +HWY_API void CopyBytes(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { #if HWY_COMPILER_MSVC memcpy(to, from, kBytes); #else - __builtin_memcpy(static_cast(to), static_cast(from), - kBytes); + __builtin_memcpy(to, from, kBytes); #endif } @@ -331,7 +395,7 @@ HWY_API void ZeroBytes(void* to, size_t num_bytes) { #if HWY_ARCH_X86 static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 -#elif HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && \ +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ __riscv_v_intrinsic >= 11000 // Not actually an upper bound on the size. static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; @@ -347,7 +411,7 @@ static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; // exceed the stack size. 
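Aside (illustrative, not part of the patch): the message-carrying assert variants and the now `HWY_RESTRICT`-qualified `CopyBytes` above, in a hypothetical helper:

```cpp
#include <cstdint>

#include "hwy/base.h"

// Hypothetical helper: copy a fixed-size header between non-overlapping
// buffers (the HWY_RESTRICT parameters promise no aliasing).
void CopyHeader(const uint8_t* HWY_RESTRICT from, uint8_t* HWY_RESTRICT to) {
  HWY_ASSERT_M(from != nullptr && to != nullptr, "null buffer");
  HWY_DASSERT_M(from != to, "source and destination alias");  // debug-only
  hwy::CopyBytes<16>(from, to);
}
```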
#if HWY_ARCH_X86 #define HWY_ALIGN_MAX alignas(64) -#elif HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && \ +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ __riscv_v_intrinsic >= 11000 #define HWY_ALIGN_MAX alignas(8) // only elements need be aligned #else @@ -357,349 +421,11 @@ static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; //------------------------------------------------------------------------------ // Lane types -#pragma pack(push, 1) - -// float16_t load/store/conversion intrinsics are always supported on Armv8 and -// VFPv4 (except with MSVC). On Armv7 Clang requires __ARM_FP & 2; GCC requires -// -mfp16-format=ieee. -#if (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) || \ - (HWY_COMPILER_CLANG && defined(__ARM_FP) && (__ARM_FP & 2)) || \ - (HWY_COMPILER_GCC_ACTUAL && defined(__ARM_FP16_FORMAT_IEEE)) -#define HWY_NEON_HAVE_FLOAT16C 1 -#else -#define HWY_NEON_HAVE_FLOAT16C 0 -#endif - -// C11 extension ISO/IEC TS 18661-3:2015 but not supported on all targets. -// Required if HWY_HAVE_FLOAT16, i.e. RVV with zvfh or AVX3_SPR (with -// sufficiently new compiler supporting avx512fp16). Do not use on clang-cl, -// which is missing __extendhfsf2. -#if ((HWY_ARCH_RVV && defined(__riscv_zvfh) && HWY_COMPILER_CLANG) || \ - (HWY_ARCH_X86 && defined(__SSE2__) && \ - ((HWY_COMPILER_CLANG >= 1600 && !HWY_COMPILER_CLANGCL) || \ - HWY_COMPILER_GCC_ACTUAL >= 1200))) -#define HWY_HAVE_C11_FLOAT16 1 -#else -#define HWY_HAVE_C11_FLOAT16 0 -#endif - -// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available: -// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip. -#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SVE_BF16) -#define HWY_SVE_HAVE_BFLOAT16 1 -#else -#define HWY_SVE_HAVE_BFLOAT16 0 -#endif - -// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name -// by concatenating base type and bits. We use a wrapper class instead of a -// typedef to the native type to ensure that the same symbols, e.g. for VQSort, -// are generated regardless of F16 support; see #1684. -struct float16_t { -#if HWY_NEON_HAVE_FLOAT16C // ACLE's __fp16 - using Raw = __fp16; -#elif HWY_HAVE_C11_FLOAT16 // C11 _Float16 - using Raw = _Float16; -#elif __cplusplus > 202002L && defined(__STDCPP_FLOAT16_T__) // C++23 - using Raw = std::float16_t; -#else -#define HWY_EMULATE_FLOAT16 - using Raw = uint16_t; - Raw bits; -#endif // float16_t - -// When backed by a native type, ensure the wrapper behaves like the native -// type by forwarding all operators. Unfortunately it seems difficult to reuse -// this code in a base class, so we repeat it in bfloat16_t. 
-#ifndef HWY_EMULATE_FLOAT16 - Raw raw; - - float16_t() noexcept = default; - template - constexpr float16_t(T arg) noexcept : raw(static_cast(arg)) {} - float16_t& operator=(Raw arg) noexcept { - raw = arg; - return *this; - } - constexpr float16_t(const float16_t&) noexcept = default; - float16_t& operator=(const float16_t&) noexcept = default; - constexpr operator Raw() const noexcept { return raw; } - - template - float16_t& operator+=(T rhs) noexcept { - raw = static_cast(raw + rhs); - return *this; - } - - template - float16_t& operator-=(T rhs) noexcept { - raw = static_cast(raw - rhs); - return *this; - } - - template - float16_t& operator*=(T rhs) noexcept { - raw = static_cast(raw * rhs); - return *this; - } - - template - float16_t& operator/=(T rhs) noexcept { - raw = static_cast(raw / rhs); - return *this; - } - - float16_t operator--() noexcept { - raw = static_cast(raw - Raw{1}); - return *this; - } - - float16_t operator--(int) noexcept { - raw = static_cast(raw - Raw{1}); - return *this; - } - - float16_t operator++() noexcept { - raw = static_cast(raw + Raw{1}); - return *this; - } - - float16_t operator++(int) noexcept { - raw = static_cast(raw + Raw{1}); - return *this; - } - - constexpr float16_t operator-() const noexcept { - return float16_t(static_cast(-raw)); - } - constexpr float16_t operator+() const noexcept { return *this; } -#endif // HWY_EMULATE_FLOAT16 -}; - -#ifndef HWY_EMULATE_FLOAT16 -constexpr inline bool operator==(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw == rhs.raw; -} -constexpr inline bool operator!=(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw != rhs.raw; -} -constexpr inline bool operator<(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw < rhs.raw; -} -constexpr inline bool operator<=(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw <= rhs.raw; -} -constexpr inline bool operator>(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw > rhs.raw; -} -constexpr inline bool operator>=(float16_t lhs, float16_t rhs) noexcept { - return lhs.raw >= rhs.raw; -} -#endif // HWY_EMULATE_FLOAT16 - -struct bfloat16_t { -#if HWY_SVE_HAVE_BFLOAT16 - using Raw = __bf16; -#elif __cplusplus >= 202100L && defined(__STDCPP_BFLOAT16_T__) // C++23 - using Raw = std::bfloat16_t; -#else -#define HWY_EMULATE_BFLOAT16 - using Raw = uint16_t; - Raw bits; -#endif - -#ifndef HWY_EMULATE_BFLOAT16 - Raw raw; - - bfloat16_t() noexcept = default; - template - constexpr bfloat16_t(T arg) noexcept : raw(static_cast(arg)) {} - bfloat16_t& operator=(Raw arg) noexcept { - raw = arg; - return *this; - } - constexpr bfloat16_t(const bfloat16_t&) noexcept = default; - bfloat16_t& operator=(const bfloat16_t&) noexcept = default; - constexpr operator Raw() const noexcept { return raw; } - - template - bfloat16_t& operator+=(T rhs) noexcept { - raw = static_cast(raw + rhs); - return *this; - } - - template - bfloat16_t& operator-=(T rhs) noexcept { - raw = static_cast(raw - rhs); - return *this; - } - - template - bfloat16_t& operator*=(T rhs) noexcept { - raw = static_cast(raw * rhs); - return *this; - } - - template - bfloat16_t& operator/=(T rhs) noexcept { - raw = static_cast(raw / rhs); - return *this; - } - - bfloat16_t operator--() noexcept { - raw = static_cast(raw - Raw{1}); - return *this; - } - - bfloat16_t operator--(int) noexcept { - raw = static_cast(raw - Raw{1}); - return *this; - } - - bfloat16_t operator++() noexcept { - raw = static_cast(raw + Raw{1}); - return *this; - } - - bfloat16_t operator++(int) noexcept { - raw = 
static_cast(raw + Raw{1}); - return *this; - } - - constexpr bfloat16_t operator-() const noexcept { - return bfloat16_t(static_cast(-raw)); - } - constexpr bfloat16_t operator+() const noexcept { return *this; } -#endif // HWY_EMULATE_BFLOAT16 -}; - -#ifndef HWY_EMULATE_BFLOAT16 -constexpr inline bool operator==(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw == rhs.raw; -} -constexpr inline bool operator!=(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw != rhs.raw; -} -constexpr inline bool operator<(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw < rhs.raw; -} -constexpr inline bool operator<=(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw <= rhs.raw; -} -constexpr inline bool operator>(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw > rhs.raw; -} -constexpr inline bool operator>=(bfloat16_t lhs, bfloat16_t rhs) noexcept { - return lhs.raw >= rhs.raw; -} -#endif // HWY_EMULATE_BFLOAT16 - -#pragma pack(pop) - -HWY_API float F32FromF16(float16_t f16) { -#ifdef HWY_EMULATE_FLOAT16 - uint16_t bits16; - CopySameSize(&f16, &bits16); - const uint32_t sign = static_cast(bits16 >> 15); - const uint32_t biased_exp = (bits16 >> 10) & 0x1F; - const uint32_t mantissa = bits16 & 0x3FF; - - // Subnormal or zero - if (biased_exp == 0) { - const float subnormal = - (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); - return sign ? -subnormal : subnormal; - } - - // Normalized: convert the representation directly (faster than ldexp/tables). - const uint32_t biased_exp32 = biased_exp + (127 - 15); - const uint32_t mantissa32 = mantissa << (23 - 10); - const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; - - float result; - CopySameSize(&bits32, &result); - return result; -#else - return static_cast(f16); -#endif -} - -HWY_API float16_t F16FromF32(float f32) { -#ifdef HWY_EMULATE_FLOAT16 - uint32_t bits32; - CopySameSize(&f32, &bits32); - const uint32_t sign = bits32 >> 31; - const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; - const uint32_t mantissa32 = bits32 & 0x7FFFFF; - - const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); - - // Tiny or zero => zero. 
- float16_t out; - if (exp < -24) { - // restore original sign - const uint16_t bits = static_cast(sign << 15); - CopySameSize(&bits, &out); - return out; - } - - uint32_t biased_exp16, mantissa16; - - // exp = [-24, -15] => subnormal - if (exp < -14) { - biased_exp16 = 0; - const uint32_t sub_exp = static_cast(-14 - exp); - HWY_DASSERT(1 <= sub_exp && sub_exp < 11); - mantissa16 = static_cast((1u << (10 - sub_exp)) + - (mantissa32 >> (13 + sub_exp))); - } else { - // exp = [-14, 15] - biased_exp16 = static_cast(exp + 15); - HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); - mantissa16 = mantissa32 >> 13; - } - - HWY_DASSERT(mantissa16 < 1024); - const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; - HWY_DASSERT(bits16 < 0x10000); - const uint16_t narrowed = static_cast(bits16); // big-endian safe - CopySameSize(&narrowed, &out); - return out; -#else - return float16_t(static_cast(f32)); -#endif -} - -HWY_API float F32FromBF16(bfloat16_t bf) { - uint16_t bits16; - CopyBytes<2>(&bf, &bits16); - uint32_t bits = bits16; - bits <<= 16; - float f; - CopySameSize(&bits, &f); - return f; -} - -HWY_API float F32FromF16Mem(const void* ptr) { - float16_t f16; - CopyBytes<2>(ptr, &f16); - return F32FromF16(f16); -} - -HWY_API float F32FromBF16Mem(const void* ptr) { - bfloat16_t bf; - CopyBytes<2>(ptr, &bf); - return F32FromBF16(bf); -} - -HWY_API bfloat16_t BF16FromF32(float f) { - uint32_t bits; - CopySameSize(&f, &bits); - const uint16_t bits16 = static_cast(bits >> 16); - bfloat16_t bf; - CopySameSize(&bits16, &bf); - return bf; -} +// hwy::float16_t and hwy::bfloat16_t are forward declared here to allow +// BitCastScalar to be implemented before the implementations of the +// hwy::float16_t and hwy::bfloat16_t types +struct float16_t; +struct bfloat16_t; using float32_t = float; using float64_t = double; @@ -729,24 +455,6 @@ struct alignas(8) K32V32 { #pragma pack(pop) -#ifdef HWY_EMULATE_FLOAT16 - -static inline HWY_MAYBE_UNUSED bool operator<(const float16_t& a, - const float16_t& b) { - return F32FromF16(a) < F32FromF16(b); -} -// Required for std::greater. -static inline HWY_MAYBE_UNUSED bool operator>(const float16_t& a, - const float16_t& b) { - return F32FromF16(a) > F32FromF16(b); -} -static inline HWY_MAYBE_UNUSED bool operator==(const float16_t& a, - const float16_t& b) { - return F32FromF16(a) == F32FromF16(b); -} - -#endif // HWY_EMULATE_FLOAT16 - static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, const uint128_t& b) { return (a.hi == b.hi) ? 
a.lo < b.lo : a.hi < b.hi; @@ -761,6 +469,13 @@ static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, return a.lo == b.lo && a.hi == b.hi; } +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const uint128_t& n) { + return os << "[hi=" << n.hi << ",lo=" << n.lo << "]"; +} +#endif + static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, const K64V64& b) { return a.key < b.key; @@ -775,6 +490,13 @@ static inline HWY_MAYBE_UNUSED bool operator==(const K64V64& a, return a.key == b.key; } +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K64V64& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, const K32V32& b) { return a.key < b.key; @@ -789,6 +511,13 @@ static inline HWY_MAYBE_UNUSED bool operator==(const K32V32& a, return a.key == b.key; } +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K32V32& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + //------------------------------------------------------------------------------ // Controlling overload resolution (SFINAE) @@ -817,6 +546,12 @@ HWY_API constexpr bool IsSame() { return IsSameT::value; } +// Returns whether T matches either of U1 or U2 +template +HWY_API constexpr bool IsSameEither() { + return IsSameT::value || IsSameT::value; +} + template struct IfT { using type = Then; @@ -830,93 +565,1441 @@ struct IfT { template using If = typename IfT::type; -// Insert into template/function arguments to enable this overload only for -// vectors of exactly, at most (LE), or more than (GT) this many bytes. -// -// As an example, checking for a total size of 16 bytes will match both -// Simd and Simd. -#define HWY_IF_V_SIZE(T, kN, bytes) \ - hwy::EnableIf* = nullptr -#define HWY_IF_V_SIZE_LE(T, kN, bytes) \ - hwy::EnableIf* = nullptr -#define HWY_IF_V_SIZE_GT(T, kN, bytes) \ - hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr +template +struct IsConstT { + enum { value = 0 }; +}; -#define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr -#define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr -#define HWY_IF_LANES_GT(kN, lanes) hwy::EnableIf<(kN > lanes)>* = nullptr +template +struct IsConstT { + enum { value = 1 }; +}; -#define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_SIGNED(T) \ - hwy::EnableIf() && !IsFloat() && !IsSpecialFloat()>* = \ - nullptr -#define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_FLOAT3264(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_NOT_FLOAT3264(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_SPECIAL_FLOAT(T) \ - hwy::EnableIf()>* = nullptr -#define HWY_IF_NOT_SPECIAL_FLOAT(T) \ - hwy::EnableIf()>* = nullptr -#define HWY_IF_FLOAT_OR_SPECIAL(T) \ - hwy::EnableIf() || hwy::IsSpecialFloat()>* = nullptr -#define HWY_IF_NOT_FLOAT_NOR_SPECIAL(T) \ - hwy::EnableIf() && !hwy::IsSpecialFloat()>* = nullptr +template +HWY_API constexpr bool IsConst() { + return IsConstT::value; +} -#define HWY_IF_T_SIZE(T, bytes) hwy::EnableIf* = nullptr -#define HWY_IF_NOT_T_SIZE(T, bytes) \ - hwy::EnableIf* = nullptr -// bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds -// too similar. 
If you want the opposite of this (2 or 4 bytes), ask for those -// bits explicitly (0x14) instead of attempting to 'negate' 0x102. -#define HWY_IF_T_SIZE_ONE_OF(T, bit_array) \ - hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr +template +struct RemoveConstT { + using type = T; +}; +template +struct RemoveConstT { + using type = T; +}; + +template +using RemoveConst = typename RemoveConstT::type; + +template +struct RemoveVolatileT { + using type = T; +}; +template +struct RemoveVolatileT { + using type = T; +}; + +template +using RemoveVolatile = typename RemoveVolatileT::type; + +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; + +template +using RemoveRef = typename RemoveRefT::type; + +template +using RemoveCvRef = RemoveConst>>; + +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; + +template +using RemovePtr = typename RemovePtrT::type; + +// Insert into template/function arguments to enable this overload only for +// vectors of exactly, at most (LE), or more than (GT) this many bytes. +// +// As an example, checking for a total size of 16 bytes will match both +// Simd and Simd. +#define HWY_IF_V_SIZE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_LE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_GT(T, kN, bytes) \ + hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr + +#define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr +#define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr +#define HWY_IF_LANES_GT(kN, lanes) hwy::EnableIf<(kN > lanes)>* = nullptr + +#define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SIGNED(T) \ + hwy::EnableIf() && !hwy::IsFloat() && \ + !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT_OR_SPECIAL(T) \ + hwy::EnableIf() || hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL(T) \ + hwy::EnableIf() && !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_INTEGER(T) hwy::EnableIf()>* = nullptr + +#define HWY_IF_T_SIZE(T, bytes) hwy::EnableIf* = nullptr +#define HWY_IF_NOT_T_SIZE(T, bytes) \ + hwy::EnableIf* = nullptr +// bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds +// too similar. If you want the opposite of this (2 or 4 bytes), ask for those +// bits explicitly (0x14) instead of attempting to 'negate' 0x102. 
+#define HWY_IF_T_SIZE_ONE_OF(T, bit_array) \ + hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr +#define HWY_IF_T_SIZE_LE(T, bytes) \ + hwy::EnableIf<(sizeof(T) <= (bytes))>* = nullptr +#define HWY_IF_T_SIZE_GT(T, bytes) \ + hwy::EnableIf<(sizeof(T) > (bytes))>* = nullptr + +#define HWY_IF_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr +#define HWY_IF_NOT_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr + +// One of two expected types +#define HWY_IF_SAME2(T, expected1, expected2) \ + hwy::EnableIf< \ + hwy::IsSameEither, expected1, expected2>()>* = \ + nullptr + +#define HWY_IF_U8(T) HWY_IF_SAME(T, uint8_t) +#define HWY_IF_U16(T) HWY_IF_SAME(T, uint16_t) +#define HWY_IF_U32(T) HWY_IF_SAME(T, uint32_t) +#define HWY_IF_U64(T) HWY_IF_SAME(T, uint64_t) + +#define HWY_IF_I8(T) HWY_IF_SAME(T, int8_t) +#define HWY_IF_I16(T) HWY_IF_SAME(T, int16_t) +#define HWY_IF_I32(T) HWY_IF_SAME(T, int32_t) +#define HWY_IF_I64(T) HWY_IF_SAME(T, int64_t) + +#define HWY_IF_BF16(T) HWY_IF_SAME(T, hwy::bfloat16_t) +#define HWY_IF_NOT_BF16(T) HWY_IF_NOT_SAME(T, hwy::bfloat16_t) + +#define HWY_IF_F16(T) HWY_IF_SAME(T, hwy::float16_t) +#define HWY_IF_NOT_F16(T) HWY_IF_NOT_SAME(T, hwy::float16_t) + +#define HWY_IF_F32(T) HWY_IF_SAME(T, float) +#define HWY_IF_F64(T) HWY_IF_SAME(T, double) + +// Use instead of HWY_IF_T_SIZE to avoid ambiguity with float16_t/float/double +// overloads. +#define HWY_IF_UI8(T) HWY_IF_SAME2(T, uint8_t, int8_t) +#define HWY_IF_UI16(T) HWY_IF_SAME2(T, uint16_t, int16_t) +#define HWY_IF_UI32(T) HWY_IF_SAME2(T, uint32_t, int32_t) +#define HWY_IF_UI64(T) HWY_IF_SAME2(T, uint64_t, int64_t) + +#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ + hwy::EnableIf* = nullptr + +// Empty struct used as a size tag type. +template +struct SizeTag {}; + +template +class DeclValT { + private: + template + static URef TryAddRValRef(int); + template + static U TryAddRValRef(Arg); + + public: + using type = decltype(TryAddRValRef(0)); + enum { kDisableDeclValEvaluation = 1 }; +}; + +// hwy::DeclVal() can only be used in unevaluated contexts such as within an +// expression of a decltype specifier. 
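Aside (illustrative, not part of the patch): the expanded `HWY_IF_*` SFINAE helpers above select overloads by lane-type properties; the function names below are made up:

```cpp
#include <cstdint>

#include "hwy/base.h"

// Exactly one overload is enabled for each lane type T.
template <typename T, HWY_IF_FLOAT(T)>
const char* LaneKind(T) { return "float"; }

template <typename T, HWY_IF_SIGNED(T)>
const char* LaneKind(T) { return "signed"; }

template <typename T, HWY_IF_UNSIGNED(T)>
const char* LaneKind(T) { return "unsigned"; }

// Restricted to 32-bit integer lanes of either signedness.
template <typename T, HWY_IF_UI32(T)>
T Increment32(T v) { return static_cast<T>(v + 1); }

int main() {
  return (LaneKind(1.0f)[0] == 'f' && LaneKind(int32_t{-1})[0] == 's' &&
          Increment32(uint32_t{41}) == 42u)
             ? 0
             : 1;
}
```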
+ +// hwy::DeclVal() does not require that T have a public default constructor +template +HWY_API typename DeclValT::type DeclVal() noexcept { + static_assert(!DeclValT::kDisableDeclValEvaluation, + "DeclVal() cannot be used in an evaluated context"); +} + +template +struct IsArrayT { + enum { value = 0 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +static constexpr bool IsArray() { + return IsArrayT::value; +} + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4180, ignored "-Wignored-qualifiers") +#endif + +template +class IsConvertibleT { + private: + template + static hwy::SizeTag<1> TestFuncWithToArg(T); + + template + static decltype(IsConvertibleT::template TestFuncWithToArg( + DeclVal())) + TryConvTest(int); + + template + static hwy::SizeTag<0> TryConvTest(Arg); + + public: + enum { + value = (IsSame>, void>() && + IsSame>, void>()) || + (!IsArray() && + (IsSame())>() || + !IsSame, RemoveConst>()) && + IsSame(0)), hwy::SizeTag<1>>()) + }; +}; + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(pop) +#endif + +template +HWY_API constexpr bool IsConvertible() { + return IsConvertibleT::value; +} + +template +class IsStaticCastableT { + private: + template (DeclVal()))> + static hwy::SizeTag<1> TryStaticCastTest(int); + + template + static hwy::SizeTag<0> TryStaticCastTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsStaticCastable() { + return IsStaticCastableT::value; +} + +#define HWY_IF_CASTABLE(From, To) \ + hwy::EnableIf()>* = nullptr + +#define HWY_IF_OP_CASTABLE(op, T, Native) \ + HWY_IF_CASTABLE(decltype(DeclVal() op DeclVal()), Native) + +template +class IsAssignableT { + private: + template () = DeclVal())> + static hwy::SizeTag<1> TryAssignTest(int); + + template + static hwy::SizeTag<0> TryAssignTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsAssignable() { + return IsAssignableT::value; +} + +#define HWY_IF_ASSIGNABLE(T, From) \ + hwy::EnableIf()>* = nullptr + +// ---------------------------------------------------------------------------- +// IsSpecialFloat + +// These types are often special-cased and not supported in all ops. +template +HWY_API constexpr bool IsSpecialFloat() { + return IsSameEither, hwy::float16_t, hwy::bfloat16_t>(); +} + +// ----------------------------------------------------------------------------- +// IsIntegerLaneType and IsInteger + +template +HWY_API constexpr bool IsIntegerLaneType() { + return false; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} + +namespace detail { + +template +static HWY_INLINE constexpr bool IsNonCvInteger() { + // NOTE: Do not add a IsNonCvInteger() specialization below as it is + // possible for IsSame() to be true when compiled with MSVC + // with the /Zc:wchar_t- option. 
+ return IsIntegerLaneType() || IsSame() || + IsSameEither() || + IsSameEither(); +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +#if defined(__cpp_char8_t) && __cpp_char8_t >= 201811L +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +#endif +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} + +} // namespace detail + +template +HWY_API constexpr bool IsInteger() { + return detail::IsNonCvInteger>(); +} + +// ----------------------------------------------------------------------------- +// BitCastScalar + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +#define HWY_BITCASTSCALAR_CONSTEXPR constexpr +#else +#define HWY_BITCASTSCALAR_CONSTEXPR +#endif + +#if __cpp_constexpr >= 201304L +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#else +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template +struct BitCastScalarSrcCastHelper { + static HWY_INLINE constexpr const From& CastSrcValRef(const From& val) { + return val; + } +}; + +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 +// Workaround for Clang 9 constexpr __builtin_bit_cast bug +template >() && + hwy::IsInteger>()>* = nullptr> +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + static_assert(sizeof(To) == sizeof(From), + "sizeof(To) == sizeof(From) must be true"); + return static_cast(val); +} + +template >() && + hwy::IsInteger>())>* = nullptr> +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + return __builtin_bit_cast(To, val); +} +#endif // HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If From is hwy::float16_t or hwy::bfloat16_t, first cast val to either + // const typename From::Native& or const uint16_t& using + // detail::BitCastScalarSrcCastHelper>::CastSrcValRef to + // allow BitCastScalar from hwy::float16_t or hwy::bfloat16_t to be constexpr + // if To is not a pointer type, union type, or a struct/class containing a + // pointer, union, or reference subobject +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + return detail::BuiltinBitCastScalar( + detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + val)); +#else + return __builtin_bit_cast( + To, detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + 
val)); +#endif +} +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If To is hwy::float16_t or hwy::bfloat16_t, first do a BitCastScalar of val + // to uint16_t, and then bit cast the uint16_t value to To using To::FromBits + // as hwy::float16_t::FromBits and hwy::bfloat16_t::FromBits are guaranteed to + // be constexpr if the __builtin_bit_cast intrinsic is available. + return To::FromBits(BitCastScalar(val)); +} +#else +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + To result; + CopySameSize(&val, &result); + return result; +} +#endif + +//------------------------------------------------------------------------------ +// F16 lane type + +#pragma pack(push, 1) + +// Compiler supports __fp16 and load/store/conversion NEON intrinsics, which are +// included in Armv8 and VFPv4 (except with MSVC). On Armv7 Clang requires +// __ARM_FP & 2 whereas Armv7 GCC requires -mfp16-format=ieee. +#if (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) || \ + (HWY_COMPILER_CLANG && defined(__ARM_FP) && (__ARM_FP & 2)) || \ + (HWY_COMPILER_GCC_ACTUAL && defined(__ARM_FP16_FORMAT_IEEE)) +#define HWY_NEON_HAVE_F16C 1 +#else +#define HWY_NEON_HAVE_F16C 0 +#endif + +// RVV with f16 extension supports _Float16 and f16 vector ops. If set, implies +// HWY_HAVE_FLOAT16. +#if HWY_ARCH_RISCV && defined(__riscv_zvfh) && HWY_COMPILER_CLANG >= 1600 +#define HWY_RVV_HAVE_F16_VEC 1 +#else +#define HWY_RVV_HAVE_F16_VEC 0 +#endif + +// x86 compiler supports _Float16, not necessarily with operators. +// Avoid clang-cl because it lacks __extendhfsf2. +#if HWY_ARCH_X86 && defined(__SSE2__) && defined(__FLT16_MAX__) && \ + ((HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL) || \ + HWY_COMPILER_GCC_ACTUAL >= 1200) +#define HWY_SSE2_HAVE_F16_TYPE 1 +#else +#define HWY_SSE2_HAVE_F16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_F16_TYPE +// Compiler supports _Float16, not necessarily with operators. +#if HWY_NEON_HAVE_F16C || HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE +#define HWY_HAVE_SCALAR_F16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_F16_TYPE 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#ifndef HWY_HAVE_SCALAR_F16_OPERATORS +// Recent enough compiler also has operators. +#if HWY_HAVE_SCALAR_F16_TYPE && \ + (HWY_COMPILER_CLANG >= 1800 || HWY_COMPILER_GCC_ACTUAL >= 1200 || \ + (HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL && \ + !defined(_WIN32)) || \ + (HWY_ARCH_ARM && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800))) +#define HWY_HAVE_SCALAR_F16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_F16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +namespace detail { + +template , bool = IsSpecialFloat()> +struct SpecialFloatUnwrapArithOpOperandT {}; + +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = T; +}; + +template +using SpecialFloatUnwrapArithOpOperand = + typename SpecialFloatUnwrapArithOpOperandT::type; + +template > +struct NativeSpecialFloatToWrapperT { + using type = T; +}; + +template +using NativeSpecialFloatToWrapper = + typename NativeSpecialFloatToWrapperT::type; + +} // namespace detail + +// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name +// by concatenating base type and bits. We use a wrapper class instead of a +// typedef to the native type to ensure that the same symbols, e.g. for VQSort, +// are generated regardless of F16 support; see #1684. 
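Aside (illustrative, not part of the patch): `BitCastScalar` defined above is a scalar bit cast that is constexpr where `__builtin_bit_cast` is available and otherwise falls back to a byte copy; a minimal sketch:

```cpp
#include <cstdint>

#include "hwy/base.h"

int main() {
  const float one = 1.0f;
  // Reinterpret the object representation without undefined behavior.
  const uint32_t bits = hwy::BitCastScalar<uint32_t>(one);  // 0x3F800000
  const float back = hwy::BitCastScalar<float>(bits);

  // The special-float overload routes through float16_t::FromBits, yielding
  // the binary16 encoding of 1.0 (0x3C00) wrapped in hwy::float16_t.
  const hwy::float16_t h = hwy::BitCastScalar<hwy::float16_t>(uint16_t{0x3C00});
  return (bits == 0x3F800000u && back == 1.0f && sizeof(h) == 2) ? 0 : 1;
}
```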
+struct alignas(2) float16_t { +#if HWY_HAVE_SCALAR_F16_TYPE +#if HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE + using Native = _Float16; +#elif HWY_NEON_HAVE_F16C + using Native = __fp16; +#else +#error "Logic error: condition should be 'all but NEON_HAVE_F16C'" +#endif +#elif HWY_IDE + using Native = uint16_t; +#endif // HWY_HAVE_SCALAR_F16_TYPE + + union { +#if HWY_HAVE_SCALAR_F16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_F16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. + uint16_t bits; + }; + + // Default init and copying. + float16_t() noexcept = default; + constexpr float16_t(const float16_t&) noexcept = default; + constexpr float16_t(float16_t&&) noexcept = default; + float16_t& operator=(const float16_t&) noexcept = default; + float16_t& operator=(float16_t&&) noexcept = default; + +#if HWY_HAVE_SCALAR_F16_TYPE + // NEON vget/set_lane intrinsics and SVE `svaddv` could use explicit + // float16_t(intrinsic()), but user code expects implicit conversions. + constexpr float16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR float16_t FromBits(uint16_t bits) { + return float16_t(BitCastScalar(bits)); + } +#else + + private: + struct F16FromU16BitsTag {}; + constexpr float16_t(F16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr float16_t FromBits(uint16_t bits) { + return float16_t(F16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. +#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + template , float16_t>() && + IsConvertible()>* = nullptr> + constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + template , float16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR float16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR float16_t operator--(int) noexcept { + float16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } + + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR float16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } + + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR float16_t operator++(int) noexcept { + float16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } + + constexpr float16_t operator-() const noexcept { + return float16_t(static_cast(-native)); + } + constexpr float16_t operator+() const noexcept { return *this; } + + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
+#define HWY_FLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr float16_t op_func(const float16_t& rhs) const noexcept { \ + return float16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func( \ + const hwy::float16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + + HWY_FLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_FLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_FLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_FLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_FLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS +}; +static_assert(sizeof(hwy::float16_t) == 2, "Wrong size of float16_t"); + +#if HWY_HAVE_SCALAR_F16_TYPE +namespace detail { + +#if HWY_HAVE_SCALAR_F16_OPERATORS +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::float16_t::Native; +}; +#endif + +template +struct NativeSpecialFloatToWrapperT { + using type = hwy::float16_t; +}; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_INLINE constexpr const hwy::float16_t::Native& CastSrcValRef( + const hwy::float16_t& val) { + return val.native; + } +#else + static HWY_INLINE constexpr const uint16_t& CastSrcValRef( + const hwy::float16_t& val) { + return val.bits; + } +#endif +}; + +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_F16_CONSTEXPR constexpr +#else +#define HWY_F16_CONSTEXPR HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +HWY_API HWY_F16_CONSTEXPR float F32FromF16(float16_t f16) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return static_cast(f16); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint16_t bits16 = BitCastScalar(f16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized, infinity or NaN: convert the representation directly + // (faster than ldexp/tables). + const uint32_t biased_exp32 = + biased_exp == 31 ? 
0xFF : biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + return BitCastScalar(bits32); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +#if HWY_IS_DEBUG_BUILD && \ + (HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926) +#if defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +// If C++23 if !consteval support is available, only execute +// HWY_DASSERT(condition) if F16FromF32 is not called from a constant-evaluated +// context to avoid compilation errors. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if !consteval { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#elif HWY_HAS_BUILTIN(__builtin_is_constant_evaluated) || \ + HWY_COMPILER_MSVC >= 1926 +// If the __builtin_is_constant_evaluated() intrinsic is available, +// only do HWY_DASSERT(condition) if __builtin_is_constant_evaluated() returns +// false to avoid compilation errors if F16FromF32 is called from a +// constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if (!__builtin_is_constant_evaluated()) { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#else +// If C++23 if !consteval support is not available, +// the __builtin_is_constant_evaluated() intrinsic is not available, +// HWY_IS_DEBUG_BUILD is 1, and the __builtin_bit_cast intrinsic is available, +// do not do a HWY_DASSERT to avoid compilation errors if F16FromF32 is +// called from a constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + } while (0) +#endif // defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +#else +// If HWY_IS_DEBUG_BUILD is 0 or the __builtin_bit_cast intrinsic is not +// available, define HWY_F16_FROM_F32_DASSERT(condition) as +// HWY_DASSERT(condition) +#define HWY_F16_FROM_F32_DASSERT(condition) HWY_DASSERT(condition) +#endif // HWY_IS_DEBUG_BUILD && (HWY_HAS_BUILTIN(__builtin_bit_cast) || + // HWY_COMPILER_MSVC >= 1926) + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF32(float f32) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return float16_t(static_cast(f32)); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint32_t bits32 = BitCastScalar(f32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + constexpr uint32_t kMantissaMask = 0x7FFFFF; + const uint32_t mantissa32 = bits32 & kMantissaMask; + + // Before shifting (truncation), round to nearest even to reduce bias. If + // the lowest remaining mantissa bit is odd, increase the offset. Example + // with the lowest remaining bit (left) and next lower two bits; the + // latter, plus two more, will be truncated. 
+ // 0[00] + 1 = 0[01] + // 0[01] + 1 = 0[10] + // 0[10] + 1 = 0[11] (round down toward even) + // 0[11] + 1 = 1[00] (round up) + // 1[00] + 10 = 1[10] + // 1[01] + 10 = 1[11] + // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out) + // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out) + + // If |f32| >= 2^-24, f16_ulp_bit_idx is the index of the F32 mantissa bit + // that will be shifted down into the ULP bit of the rounded down F16 result + + // The biased F32 exponent of 2^-14 (the smallest positive normal F16 value) + // is 113, and bit 13 of the F32 mantissa will be shifted down to into the ULP + // bit of the rounded down F16 result if |f32| >= 2^14 + + // If |f32| < 2^-24, f16_ulp_bit_idx is equal to 24 as there are 24 mantissa + // bits (including the implied 1 bit) in the mantissa of a normal F32 value + // and as we want to round up the mantissa if |f32| > 2^-25 && |f32| < 2^-24 + const int32_t f16_ulp_bit_idx = + HWY_MIN(HWY_MAX(126 - static_cast(biased_exp32), 13), 24); + const uint32_t odd_bit = ((mantissa32 | 0x800000u) >> f16_ulp_bit_idx) & 1; + const uint32_t rounded = + mantissa32 + odd_bit + (uint32_t{1} << (f16_ulp_bit_idx - 1)) - 1u; + const bool carry = rounded >= (1u << 23); + + const int32_t exp = static_cast(biased_exp32) - 127 + carry; + + // Tiny or zero => zero. + if (exp < -24) { + // restore original sign + return float16_t::FromBits(static_cast(sign << 15)); + } + + // If biased_exp16 would be >= 31, first check whether the input was NaN so we + // can set the mantissa to nonzero. + const bool is_nan = (biased_exp32 == 255) && mantissa32 != 0; + const bool overflowed = exp >= 16; + const uint32_t biased_exp16 = + static_cast(HWY_MIN(HWY_MAX(0, exp + 15), 31)); + // exp = [-24, -15] => subnormal, shift the mantissa. + const uint32_t sub_exp = static_cast(HWY_MAX(-14 - exp, 0)); + HWY_F16_FROM_F32_DASSERT(sub_exp < 11); + const uint32_t shifted_mantissa = + (rounded & kMantissaMask) >> (23 - 10 + sub_exp); + const uint32_t leading = sub_exp == 0u ? 0u : (1024u >> sub_exp); + const uint32_t mantissa16 = is_nan ? 0x3FF + : overflowed ? 0u + : (leading + shifted_mantissa); + +#if HWY_IS_DEBUG_BUILD + if (exp < -14) { + HWY_F16_FROM_F32_DASSERT(biased_exp16 == 0); + HWY_F16_FROM_F32_DASSERT(sub_exp >= 1); + } else if (exp <= 15) { + HWY_F16_FROM_F32_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + HWY_F16_FROM_F32_DASSERT(sub_exp == 0); + } +#endif + + HWY_F16_FROM_F32_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_F16_FROM_F32_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + return float16_t::FromBits(narrowed); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF64(double f64) { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return float16_t(static_cast(f64)); +#else + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 29 bits zeroed out to + // ensure that the result is correctly rounded to a F16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 24 bits of precision. + + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-24) (the smallest positive denormal F16 value). 
+ + // It is also okay if bit 29 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 13 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a F16. + + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step will ensure that the F64 to F32 conversion is + // exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-126) (the smallest normal F32 value) and + // HighestValue() (the largest finite F32 value) + + // It is okay if the F64 to F32 conversion is inexact for F64 values that have + // a magnitude that is less than 2^(-126) as the magnitude of a denormal F32 + // value is much smaller than 2^(-24) (the smallest positive denormal F16 + // value). + + return F16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFFFE0000000ULL) | + ((BitCastScalar(f64) + 0x000000001FFFFFFFULL) & + 0x0000000020000000ULL))))); +#endif +} + +// More convenient to define outside float16_t because these may use +// F32FromF16, which is defined after the struct. +HWY_F16_CONSTEXPR inline bool operator==(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromF16(lhs) == F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator!=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromF16(lhs) != F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromF16(lhs) < F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromF16(lhs) <= F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromF16(lhs) > F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromF16(lhs) >= F32FromF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>( + float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromF16(lhs) <=> F32FromF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +//------------------------------------------------------------------------------ +// BF16 lane type + +// Compiler supports ACLE __bf16, not necessarily with operators. + +// Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug +// in GCC 13 and earlier that sometimes causes BF16 constant values to be +// incorrectly loaded on AArch64, and this GCC bug on AArch64 is +// described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867. 
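// A minimal usage sketch of the F16 conversions and comparison operators
// defined above (illustrative only; values are chosen to be exactly
// representable):
//
//   float16_t h = F16FromF32(0.5f);   // 0.5 is exactly representable in F16
//   float f = F32FromF16(h);          // == 0.5f
//   // The F16 ULP at 1.0 is 2^-10 = 0.0009765625, so nearby floats collapse
//   // to the same half-precision value:
//   F16FromF32(1.0f) == F16FromF32(1.0f + 1e-5f);  // true
//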
+ +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400) +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0 +#endif + +// x86 compiler supports __bf16, not necessarily with operators. +#ifndef HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#if HWY_ARCH_X86 && defined(__SSE2__) && \ + ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL) || \ + HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 0 +#endif +#endif // HWY_SSE2_HAVE_SCALAR_BF16_TYPE + +// Compiler supports __bf16, not necessarily with operators. +#if HWY_ARM_HAVE_SCALAR_BF16_TYPE || HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#define HWY_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_BF16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_BF16_OPERATORS +// Recent enough compiler also has operators. aarch64 clang 18 hits internal +// compiler errors on bf16 ToString, hence only enable on GCC for now. +#if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_HAVE_SCALAR_BF16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_BF16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +#define HWY_BF16_CONSTEXPR constexpr +#else +#define HWY_BF16_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#endif + +struct alignas(2) bfloat16_t { +#if HWY_HAVE_SCALAR_BF16_TYPE + using Native = __bf16; +#elif HWY_IDE + using Native = uint16_t; +#endif + + union { +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_BF16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. + uint16_t bits; + }; + + // Default init and copying + bfloat16_t() noexcept = default; + constexpr bfloat16_t(bfloat16_t&&) noexcept = default; + constexpr bfloat16_t(const bfloat16_t&) noexcept = default; + bfloat16_t& operator=(bfloat16_t&& arg) noexcept = default; + bfloat16_t& operator=(const bfloat16_t& arg) noexcept = default; + +// Only enable implicit conversions if we have a native type. +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + constexpr bfloat16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BitCastScalar(bits)); + } +#else + + private: + struct BF16FromU16BitsTag {}; + constexpr bfloat16_t(BF16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BF16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. 
+#if HWY_HAVE_SCALAR_BF16_OPERATORS || HWY_IDE + template , Native>() && + !IsSame, bfloat16_t>() && + IsConvertible()>* = nullptr> + constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + template , Native>() && + !IsSame, bfloat16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + HWY_CXX14_CONSTEXPR bfloat16_t& operator=(Native arg) noexcept { + native = arg; + return *this; + } + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR bfloat16_t operator--(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } -// Use instead of HWY_IF_T_SIZE to avoid ambiguity with float16_t/float/double -// overloads. -#define HWY_IF_UI16(T) \ - hwy::EnableIf() || IsSame()>* = nullptr -#define HWY_IF_UI32(T) \ - hwy::EnableIf() || IsSame()>* = nullptr -#define HWY_IF_UI64(T) \ - hwy::EnableIf() || IsSame()>* = nullptr -#define HWY_IF_BF16(T) hwy::EnableIf()>* = nullptr -#define HWY_IF_F16(T) hwy::EnableIf()>* = nullptr + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } -#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ - hwy::EnableIf* = nullptr + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR bfloat16_t operator++(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } -// Empty struct used as a size tag type. -template -struct SizeTag {}; + constexpr bfloat16_t operator-() const noexcept { + return bfloat16_t(static_cast(-native)); + } + constexpr bfloat16_t operator+() const noexcept { return *this; } -template -struct RemoveConstT { - using type = T; -}; -template -struct RemoveConstT { - using type = T; + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
+#define HWY_BFLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr bfloat16_t op_func(const bfloat16_t& rhs) const noexcept { \ + return bfloat16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func( \ + const hwy::bfloat16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + HWY_BFLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_BFLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_BFLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_BFLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_BFLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS }; +static_assert(sizeof(hwy::bfloat16_t) == 2, "Wrong size of bfloat16_t"); -template -using RemoveConst = typename RemoveConstT::type; +#pragma pack(pop) + +#if HWY_HAVE_SCALAR_BF16_TYPE +namespace detail { +#if HWY_HAVE_SCALAR_BF16_OPERATORS template -struct RemoveRefT { - using type = T; +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::bfloat16_t::Native; }; +#endif + template -struct RemoveRefT { - using type = T; +struct NativeSpecialFloatToWrapperT { + using type = hwy::bfloat16_t; }; -template -struct RemoveRefT { - using type = T; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_BF16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_INLINE constexpr const hwy::bfloat16_t::Native& CastSrcValRef( + const hwy::bfloat16_t& val) { + return val.native; + } +#else + static HWY_INLINE constexpr const uint16_t& CastSrcValRef( + const hwy::bfloat16_t& val) { + return val.bits; + } +#endif }; -template -using RemoveRef = typename RemoveRefT::type; +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 + +HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf) { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return static_cast(bf); +#else + return BitCastScalar(static_cast( + static_cast(BitCastScalar(bf)) << 16)); +#endif +} + +namespace detail { + +// Returns the increment to add to the bits of a finite F32 value to round a +// finite F32 to the nearest BF16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr( + const uint32_t f32_bits) { + return static_cast(((f32_bits & 0x7FFFFFFFu) < 0x7F800000u) + ? 
(0x7FFFu + ((f32_bits >> 16) & 1u)) + : 0u); +} + +// Converts f32_bits (which is the bits of a F32 value) to BF16 bits, +// rounded to the nearest F16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits( + const uint32_t f32_bits) { + // Round f32_bits to the nearest BF16 by first adding + // F32BitsToBF16RoundIncr(f32_bits) to f32_bits and then right shifting + // f32_bits + F32BitsToBF16RoundIncr(f32_bits) by 16 + + // If f32_bits is the bit representation of a NaN F32 value, make sure that + // bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16 + // values and to prevent NaN F32 values from being converted to an infinite + // BF16 value + return static_cast( + ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16) | + (static_cast((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) << 6)); +} + +} // namespace detail + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF32(float f) { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return static_cast(f); +#else + return bfloat16_t::FromBits( + detail::F32BitsToBF16Bits(BitCastScalar(f))); +#endif +} + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF64(double f64) { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return static_cast(f64); +#else + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 38 bits zeroed out to + // ensure that the result is correctly rounded to a BF16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 15 bits of precision. + + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-133) (the smallest positive denormal BF16 value). + + // It is also okay if bit 38 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 16 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a BF16. + + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step below will ensure that the F64 to F32 conversion + // is exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-135) (one-fourth of the smallest positive denormal + // BF16 value) and HighestValue() (the largest finite F32 value). + + // If |f64| is less than 2^(-135), the magnitude of the result of the F64 to + // F32 conversion is guaranteed to be less than or equal to 2^(-135), which + // ensures that the F32 to BF16 conversion is correctly rounded, even if the + // conversion of a rounded F64 value whose magnitude is less than 2^(-135) + // to a F32 is inexact. + + return BF16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFC000000000ULL) | + ((BitCastScalar(f64) + 0x0000003FFFFFFFFFULL) & + 0x0000004000000000ULL))))); +#endif +} + +// More convenient to define outside bfloat16_t because these may use +// F32FromBF16, which is defined after the struct. 
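// A minimal usage sketch of the BF16 conversions defined above (illustrative
// only): BF16 keeps the sign, the full 8-bit exponent and the top 7 mantissa
// bits of an F32, so the conversion amounts to a rounded truncation to the
// upper 16 bits.
//
//   bfloat16_t one = BF16FromF32(1.0f);   // bit pattern 0x3F80
//   float back = F32FromBF16(one);        // exactly 1.0f
//   // 1 + 2^-8 lies halfway between 1.0 and the next BF16 value (1.0078125);
//   // F32BitsToBF16Bits rounds the tie to the even mantissa, i.e. back to 1.0.
//   bfloat16_t tie = BF16FromF32(1.00390625f);  // same bits as `one`
//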
+ +HWY_BF16_CONSTEXPR inline bool operator==(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromBF16(lhs) == F32FromBF16(rhs); +#endif +} + +HWY_BF16_CONSTEXPR inline bool operator!=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromBF16(lhs) != F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromBF16(lhs) < F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromBF16(lhs) <= F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromBF16(lhs) > F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromBF16(lhs) >= F32FromBF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_BF16_CONSTEXPR inline std::partial_ordering operator<=>( + bfloat16_t lhs, bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromBF16(lhs) <=> F32FromBF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE //------------------------------------------------------------------------------ // Type relations @@ -1110,25 +2193,19 @@ constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 0x200 : 0x400)> { template HWY_API constexpr bool IsFloat3264() { - return IsSame() || IsSame(); + return IsSameEither, float, double>(); } template HWY_API constexpr bool IsFloat() { // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or // from a float, not compared. Include float16_t in case HWY_HAVE_FLOAT16=1. - return IsSame() || IsFloat3264(); -} - -// These types are often special-cased and not supported in all ops. -template -HWY_API constexpr bool IsSpecialFloat() { - return IsSame() || IsSame(); + return IsSame, float16_t>() || IsFloat3264(); } template HWY_API constexpr bool IsSigned() { - return T(0) > T(-1); + return static_cast(0) > static_cast(-1); } template <> constexpr bool IsSigned() { @@ -1138,104 +2215,113 @@ template <> constexpr bool IsSigned() { return true; } +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} + +template () && !IsIntegerLaneType()> +struct MakeLaneTypeIfIntegerT { + using type = T; +}; + +template +struct MakeLaneTypeIfIntegerT { + using type = hwy::If(), SignedFromSize, + UnsignedFromSize>; +}; + +template +using MakeLaneTypeIfInteger = typename MakeLaneTypeIfIntegerT::type; // Largest/smallest representable integer values. template HWY_API constexpr T LimitsMax() { - static_assert(!IsFloat(), "Only for integer types"); - using TU = MakeUnsigned; - return static_cast(IsSigned() ? (static_cast(~0ull) >> 1) - : static_cast(~0ull)); + static_assert(IsInteger(), "Only for integer types"); + using TU = UnsignedFromSize; + return static_cast(IsSigned() ? 
(static_cast(~TU(0)) >> 1) + : static_cast(~TU(0))); } template HWY_API constexpr T LimitsMin() { - static_assert(!IsFloat(), "Only for integer types"); - return IsSigned() ? T(-1) - LimitsMax() : T(0); + static_assert(IsInteger(), "Only for integer types"); + return IsSigned() ? static_cast(-1) - LimitsMax() + : static_cast(0); } // Largest/smallest representable value (integer or float). This naming avoids // confusion with numeric_limits::min() (the smallest positive value). // Cannot be constexpr because we use CopySameSize for [b]float16_t. template -HWY_API T LowestValue() { +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T LowestValue() { return LimitsMin(); } template <> -HWY_INLINE bfloat16_t LowestValue() { - const uint16_t kBits = 0xFF7F; // -1.1111111 x 2^127 - bfloat16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t LowestValue() { + return bfloat16_t::FromBits(uint16_t{0xFF7Fu}); // -1.1111111 x 2^127 } template <> -HWY_INLINE float16_t LowestValue() { - const uint16_t kBits = 0xFBFF; // -1.1111111111 x 2^15 - float16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t LowestValue() { + return float16_t::FromBits(uint16_t{0xFBFFu}); // -1.1111111111 x 2^15 } template <> -HWY_INLINE float LowestValue() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float LowestValue() { return -3.402823466e+38F; } template <> -HWY_INLINE double LowestValue() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double LowestValue() { return -1.7976931348623158e+308; } template -HWY_API T HighestValue() { +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T HighestValue() { return LimitsMax(); } template <> -HWY_INLINE bfloat16_t HighestValue() { - const uint16_t kBits = 0x7F7F; // 1.1111111 x 2^127 - bfloat16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t HighestValue() { + return bfloat16_t::FromBits(uint16_t{0x7F7Fu}); // 1.1111111 x 2^127 } template <> -HWY_INLINE float16_t HighestValue() { - const uint16_t kBits = 0x7BFF; // 1.1111111111 x 2^15 - float16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t HighestValue() { + return float16_t::FromBits(uint16_t{0x7BFFu}); // 1.1111111111 x 2^15 } template <> -HWY_INLINE float HighestValue() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float HighestValue() { return 3.402823466e+38F; } template <> -HWY_INLINE double HighestValue() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double HighestValue() { return 1.7976931348623158e+308; } // Difference between 1.0 and the next representable value. Equal to // 1 / (1ULL << MantissaBits()), but hard-coding ensures precision. 
template -HWY_API T Epsilon() { +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T Epsilon() { return 1; } template <> -HWY_INLINE bfloat16_t Epsilon() { - const uint16_t kBits = 0x3C00; // 0.0078125 - bfloat16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t Epsilon() { + return bfloat16_t::FromBits(uint16_t{0x3C00u}); // 0.0078125 } template <> -HWY_INLINE float16_t Epsilon() { - const uint16_t kBits = 0x1400; // 0.0009765625 - float16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t Epsilon() { + return float16_t::FromBits(uint16_t{0x1400u}); // 0.0009765625 } template <> -HWY_INLINE float Epsilon() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float Epsilon() { return 1.192092896e-7f; } template <> -HWY_INLINE double Epsilon() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double Epsilon() { return 2.2204460492503131e-16; } @@ -1278,7 +2364,8 @@ constexpr MakeUnsigned SignMask() { // Returns bitmask of the exponent field in IEEE binary16/32/64. template constexpr MakeUnsigned ExponentMask() { - return (~(MakeUnsigned{1} << MantissaBits()) + 1) & ~SignMask(); + return (~(MakeUnsigned{1} << MantissaBits()) + 1) & + static_cast>(~SignMask()); } // Returns bitmask of the mantissa field in IEEE binary16/32/64. @@ -1290,30 +2377,24 @@ constexpr MakeUnsigned MantissaMask() { // Returns 1 << mantissa_bits as a floating-point number. All integers whose // absolute value are less than this can be represented exactly. template -HWY_INLINE T MantissaEnd() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T MantissaEnd() { static_assert(sizeof(T) == 0, "Only instantiate the specializations"); return 0; } template <> -HWY_INLINE bfloat16_t MantissaEnd() { - const uint16_t kBits = 0x4300; // 1.0 x 2^7 - bfloat16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t MantissaEnd() { + return bfloat16_t::FromBits(uint16_t{0x4300u}); // 1.0 x 2^7 } template <> -HWY_INLINE float16_t MantissaEnd() { - const uint16_t kBits = 0x6400; // 1.0 x 2^10 - float16_t ret; - CopySameSize(&kBits, &ret); - return ret; +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t MantissaEnd() { + return float16_t::FromBits(uint16_t{0x6400u}); // 1.0 x 2^10 } template <> -HWY_INLINE float MantissaEnd() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float MantissaEnd() { return 8388608.0f; // 1 << 23 } template <> -HWY_INLINE double MantissaEnd() { +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double MantissaEnd() { // floating point literal with p52 requires C++17. 
return 4503599627370496.0; // 1 << 52 } @@ -1333,21 +2414,227 @@ constexpr MakeSigned MaxExponentField() { return (MakeSigned{1} << ExponentBits()) - 1; } +//------------------------------------------------------------------------------ +// Additional F16/BF16 operators + +#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +#define HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T2) \ + template < \ + typename T1, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a op b.native); \ + } + +#define HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(op, assign_op, T2) \ + template >() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename ResultT = \ + decltype(DeclVal() assign_op DeclVal())> \ + static HWY_INLINE constexpr ResultT operator assign_op(T1& a, \ + T2 b) noexcept { \ + return (a assign_op b.native); \ + } + +#define HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(op, op_func, T1) \ + HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T1) \ + template < \ + typename T2, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a.native op b); \ + } + +#if HWY_HAVE_SCALAR_F16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, -=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, float16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, float16_t) +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, -=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, 
bfloat16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, bfloat16_t) +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#undef HWY_RHS_SPECIAL_FLOAT_ARITH_OP +#undef HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP +#undef HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +//------------------------------------------------------------------------------ +// Type conversions (after IsSpecialFloat) + +HWY_API float F32FromF16Mem(const void* ptr) { + float16_t f16; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &f16); + return F32FromF16(f16); +} + +HWY_API float F32FromBF16Mem(const void* ptr) { + bfloat16_t bf; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &bf); + return F32FromBF16(bf); +} + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_BF16_TO_F16_CONSTEXPR HWY_BF16_CONSTEXPR +#else +#define HWY_BF16_TO_F16_CONSTEXPR HWY_F16_CONSTEXPR +#endif + +// For casting from TFrom to TTo +template +HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { + return static_cast(in); +} +template +HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { + return F16FromF32(static_cast(in)); +} +template +HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo +ConvertScalarTo(const hwy::bfloat16_t in) { + return F16FromF32(F32FromBF16(in)); +} +template +HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const double in) { + return F16FromF64(in); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { + return BF16FromF32(static_cast(in)); +} +template +HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo ConvertScalarTo(const hwy::float16_t in) { + return BF16FromF32(F32FromF16(in)); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const double in) { + return BF16FromF64(in); +} +template +HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { + return static_cast(F32FromF16(in)); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(TFrom in) { + return static_cast(F32FromBF16(in)); +} +// Same: return unchanged +template +HWY_API constexpr TTo ConvertScalarTo(TTo in) { + return in; +} + //------------------------------------------------------------------------------ // Helper functions template constexpr inline T1 DivCeil(T1 a, T2 b) { +#if HWY_CXX_LANG >= 201703L + HWY_DASSERT(b != 0); +#endif return (a + b - 1) / b; } -// Works for any `align`; if a power of two, compiler emits ADD+AND. +// Works for any non-zero `align`; if a power of two, compiler emits ADD+AND. constexpr inline size_t RoundUpTo(size_t what, size_t align) { return DivCeil(what, align) * align; } +// Works for any `align`; if a power of two, compiler emits AND. +constexpr inline size_t RoundDownTo(size_t what, size_t align) { + return what - (what % align); +} + +namespace detail { + +// T is unsigned or T is signed and (val >> shift_amt) is an arithmetic right +// shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::UnsignedTag /*type_tag*/, T val, + int shift_amt) { + return static_cast(val >> shift_amt); +} + +// T is signed and (val >> shift_amt) is a non-arithmetic right shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::SignedTag /*type_tag*/, T val, + int shift_amt) { + using TU = MakeUnsigned>; + return static_cast( + (val < 0) ? 
static_cast( + ~(static_cast(~static_cast(val)) >> shift_amt)) + : static_cast(static_cast(val) >> shift_amt)); +} + +} // namespace detail + +// If T is an signed integer type, ScalarShr is guaranteed to perform an +// arithmetic right shift + +// Otherwise, if T is an unsigned integer type, ScalarShr is guaranteed to +// perform a logical right shift +template )> +HWY_API constexpr RemoveCvRef ScalarShr(T val, int shift_amt) { + using NonCvRefT = RemoveCvRef; + return detail::ScalarShr( + hwy::SizeTag<((IsSigned() && + (LimitsMin() >> (sizeof(T) * 8 - 1)) != + static_cast(-1)) + ? 0x100 + : 0)>(), + static_cast(val), shift_amt); +} + // Undefined results for x == 0. HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC unsigned long index; // NOLINT _BitScanForward(&index, x); @@ -1358,6 +2645,7 @@ HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { } HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC #if HWY_ARCH_X86_64 unsigned long index; // NOLINT @@ -1383,6 +2671,7 @@ HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { // Undefined results for x == 0. HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC unsigned long index; // NOLINT _BitScanReverse(&index, x); @@ -1393,6 +2682,7 @@ HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { } HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC #if HWY_ARCH_X86_64 unsigned long index; // NOLINT @@ -1416,26 +2706,48 @@ HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { #endif // HWY_COMPILER_MSVC } -HWY_API size_t PopCount(uint64_t x) { -#if HWY_COMPILER_GCC // includes clang - return static_cast(__builtin_popcountll(x)); - // This instruction has a separate feature flag, but is often called from - // non-SIMD code, so we don't want to require dynamic dispatch. It was first - // supported by Intel in Nehalem (SSE4.2), but MSVC only predefines a macro - // for AVX, so check for that. 
+template ), + HWY_IF_T_SIZE_ONE_OF(RemoveCvRef, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API size_t PopCount(T x) { + uint32_t u32_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountl(u32_x)); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return static_cast(_mm_popcnt_u32(u32_x)); +#else + u32_x -= ((u32_x >> 1) & 0x55555555u); + u32_x = (((u32_x >> 2) & 0x33333333u) + (u32_x & 0x33333333u)); + u32_x = (((u32_x >> 4) + u32_x) & 0x0F0F0F0Fu); + u32_x += (u32_x >> 8); + u32_x += (u32_x >> 16); + return static_cast(u32_x & 0x3Fu); +#endif +} + +template ), + HWY_IF_T_SIZE(RemoveCvRef, 8)> +HWY_API size_t PopCount(T x) { + uint64_t u64_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountll(u64_x)); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) - return _mm_popcnt_u64(x); + return _mm_popcnt_u64(u64_x); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) - return _mm_popcnt_u32(static_cast(x & 0xFFFFFFFFu)) + - _mm_popcnt_u32(static_cast(x >> 32)); + return _mm_popcnt_u32(static_cast(u64_x & 0xFFFFFFFFu)) + + _mm_popcnt_u32(static_cast(u64_x >> 32)); #else - x -= ((x >> 1) & 0x5555555555555555ULL); - x = (((x >> 2) & 0x3333333333333333ULL) + (x & 0x3333333333333333ULL)); - x = (((x >> 4) + x) & 0x0F0F0F0F0F0F0F0FULL); - x += (x >> 8); - x += (x >> 16); - x += (x >> 32); - return static_cast(x & 0x7Fu); + u64_x -= ((u64_x >> 1) & 0x5555555555555555ULL); + u64_x = (((u64_x >> 2) & 0x3333333333333333ULL) + + (u64_x & 0x3333333333333333ULL)); + u64_x = (((u64_x >> 4) + u64_x) & 0x0F0F0F0F0F0F0F0FULL); + u64_x += (u64_x >> 8); + u64_x += (u64_x >> 16); + u64_x += (u64_x >> 32); + return static_cast(u64_x & 0x7Fu); #endif } @@ -1456,21 +2768,32 @@ template : static_cast(FloorLog2(static_cast(x - 1)) + 1); } -template -HWY_INLINE constexpr T AddWithWraparound(hwy::FloatTag /*tag*/, T t, size_t n) { - return t + static_cast(n); +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return t + static_cast(increment); } -template -HWY_INLINE constexpr T AddWithWraparound(hwy::NonFloatTag /*tag*/, T t, - size_t n) { +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return ConvertScalarTo(ConvertScalarTo(t) + + ConvertScalarTo(increment)); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 n) { using TU = MakeUnsigned; - return static_cast( - static_cast(static_cast(t) + static_cast(n)) & - hwy::LimitsMax()); + // Sub-int types would promote to int, not unsigned, which would trigger + // warnings, so first promote to the largest unsigned type. Due to + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87519, which affected GCC 8 + // until fixed in 9.3, we use built-in types rather than uint64_t. 
+ return static_cast(static_cast( + static_cast(static_cast(t) + + static_cast(n)) & + uint64_t{hwy::LimitsMax()})); } #if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 +#pragma intrinsic(_mul128) #pragma intrinsic(_umul128) #endif @@ -1494,7 +2817,179 @@ HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { #endif } +HWY_API int64_t Mul128(int64_t a, int64_t b, int64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __int128_t product = (__int128_t)a * (__int128_t)b; + *upper = (int64_t)(product >> 64); + return (int64_t)(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _mul128(a, b, upper); +#else + uint64_t unsigned_upper; + const int64_t lower = static_cast(Mul128( + static_cast(a), static_cast(b), &unsigned_upper)); + *upper = static_cast( + unsigned_upper - + (static_cast(ScalarShr(a, 63)) & static_cast(b)) - + (static_cast(ScalarShr(b, 63)) & static_cast(a))); + return lower; +#endif +} + +// Precomputation for fast n / divisor and n % divisor, where n is a variable +// and divisor is unchanging but unknown at compile-time. +class Divisor { + public: + explicit Divisor(uint32_t divisor) : divisor_(divisor) { + if (divisor <= 1) return; + + const uint32_t len = + static_cast(31 - Num0BitsAboveMS1Bit_Nonzero32(divisor - 1)); + const uint64_t u_hi = (2ULL << len) - divisor; + const uint32_t q = Truncate((u_hi << 32) / divisor); + + mul_ = q + 1; + shift1_ = 1; + shift2_ = len; + } + + uint32_t GetDivisor() const { return divisor_; } + + // Returns n / divisor_. + uint32_t Divide(uint32_t n) const { + const uint64_t mul = mul_; + const uint32_t t = Truncate((mul * n) >> 32); + return (t + ((n - t) >> shift1_)) >> shift2_; + } + + // Returns n % divisor_. + uint32_t Remainder(uint32_t n) const { return n - (Divide(n) * divisor_); } + + private: + static uint32_t Truncate(uint64_t x) { + return static_cast(x & 0xFFFFFFFFu); + } + + uint32_t divisor_; + uint32_t mul_ = 1; + uint32_t shift1_ = 0; + uint32_t shift2_ = 0; +}; + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::FloatTag /*tag*/, + T val) { + using TU = MakeUnsigned; + return BitCastScalar( + static_cast(BitCastScalar(val) & (~SignMask()))); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SpecialTag /*tag*/, T val) { + return ScalarAbs(hwy::FloatTag(), val); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SignedTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (val < T{0}) ? 
static_cast(TU{0} - static_cast(val)) : val; +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::UnsignedTag /*tag*/, T val) { + return val; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarAbs(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarAbs(hwy::TypeTag(), static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsNaN(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return (BitCastScalar(ScalarAbs(val)) > ExponentMask()); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsInf(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return static_cast(BitCastScalar(static_cast(val)) << 1) == + static_cast(MaxExponentTimes2()); +} + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::FloatTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (BitCastScalar(hwy::ScalarAbs(val)) < ExponentMask()); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::NonFloatTag /*tag*/, T /*val*/) { + // Integer values are always finite + return true; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarIsFinite(hwy::IsFloatTag(), + static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarCopySign(T magn, + T sign) { + using TF = RemoveCvRef>>; + using TU = MakeUnsigned; + return BitCastScalar(static_cast( + (BitCastScalar(static_cast(magn)) & (~SignMask())) | + (BitCastScalar(static_cast(sign)) & SignMask()))); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarSignBit(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + using TU = MakeUnsigned; + return ((BitCastScalar(static_cast(val)) & SignMask()) != 0); +} + // Prevents the compiler from eliding the computations that led to "output". +#if HWY_ARCH_PPC && (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ + !defined(_SOFT_FLOAT) +// Workaround to avoid test failures on PPC if compiled with Clang +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+f"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+d"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+r"(output)::"memory"); +} +#else template HWY_API void PreventElision(T&& output) { #if HWY_COMPILER_MSVC @@ -1502,8 +2997,8 @@ HWY_API void PreventElision(T&& output) { // RTL constraints). Self-assignment with #pragma optimize("off") might be // expected to prevent elision, but it does not with MSVC 2015. Type-punning // with volatile pointers generates inefficient code on MSVC 2017. - static std::atomic> dummy; - dummy.store(output, std::memory_order_relaxed); + static std::atomic> sink; + sink.store(output, std::memory_order_relaxed); #else // Works by indicating to the compiler that "output" is being read and // modified. 
The +r constraint avoids unnecessary writes to memory, but only @@ -1511,6 +3006,7 @@ HWY_API void PreventElision(T&& output) { asm volatile("" : "+r"(output) : : "memory"); #endif } +#endif } // namespace hwy diff --git a/r/src/vendor/highway/hwy/bit_set.h b/r/src/vendor/highway/hwy/bit_set.h new file mode 100644 index 00000000..f8f921be --- /dev/null +++ b/r/src/vendor/highway/hwy/bit_set.h @@ -0,0 +1,158 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BIT_SET_H_ +#define HIGHWAY_HWY_BIT_SET_H_ + +// BitSet with fast Foreach for up to 64 and 4096 members. + +#include + +#include "hwy/base.h" + +namespace hwy { + +// 64-bit specialization of std::bitset, which lacks Foreach. +class BitSet64 { + public: + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < 64); + bits_ |= (1ULL << i); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. + void SetNonzeroBitsFrom64(uint64_t bits) { bits_ |= bits; } + + void Clear(size_t i) { + HWY_DASSERT(i < 64); + bits_ &= ~(1ULL << i); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < 64); + return (bits_ & (1ULL << i)) != 0; + } + + // Returns true if any Get(i) would return true for i in [0, 64). + bool Any() const { return bits_ != 0; } + + // Returns lowest i such that Get(i). Caller must ensure Any() beforehand! + size_t First() const { + HWY_DASSERT(Any()); + return Num0BitsBelowLS1Bit_Nonzero64(bits_); + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_; } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is unaffected. + template + void Foreach(const Func& func) const { + uint64_t remaining_bits = bits_; + while (remaining_bits != 0) { + const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); + remaining_bits &= remaining_bits - 1; // clear LSB + func(i); + } + } + + size_t Count() const { return PopCount(bits_); } + + private: + uint64_t bits_ = 0; +}; + +// Two-level bitset for up to kMaxSize <= 4096 values. +template +class BitSet4096 { + public: + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Set(mod); + nonzero_.Set(idx); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. 
+ void SetNonzeroBitsFrom64(uint64_t bits) { + bits_[0].SetNonzeroBitsFrom64(bits); + if (bits) nonzero_.Set(0); + } + + void Clear(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Clear(mod); + if (!bits_[idx].Any()) { + nonzero_.Clear(idx); + } + HWY_DASSERT(!Get(i)); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + return bits_[idx].Get(mod); + } + + // Returns true if any Get(i) would return true for i in [0, 64). + bool Any() const { return nonzero_.Any(); } + + // Returns lowest i such that Get(i). Caller must ensure Any() beforehand! + size_t First() const { + HWY_DASSERT(Any()); + const size_t idx = nonzero_.First(); + return idx * 64 + bits_[idx].First(); + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_[0].Get64(); } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is only affected if changing one of + // the not yet visited BitSet64 for which Any() is true. + template + void Foreach(const Func& func) const { + nonzero_.Foreach([&func, this](size_t idx) { + bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); + }); + } + + size_t Count() const { + size_t total = 0; + nonzero_.Foreach( + [&total, this](size_t idx) { total += bits_[idx].Count(); }); + return total; + } + + private: + static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient"); + BitSet64 nonzero_; + BitSet64 bits_[kMaxSize / 64]; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_BIT_SET_H_ diff --git a/r/src/vendor/highway/hwy/cache_control.h b/r/src/vendor/highway/hwy/cache_control.h index 6e7665dd..bdfa9599 100644 --- a/r/src/vendor/highway/hwy/cache_control.h +++ b/r/src/vendor/highway/hwy/cache_control.h @@ -25,11 +25,15 @@ #define HWY_DISABLE_CACHE_CONTROL #endif +#ifndef HWY_DISABLE_CACHE_CONTROL // intrin.h is sufficient on MSVC and already included by base.h. -#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#if HWY_ARCH_X86 && !HWY_COMPILER_MSVC #include // SSE2 #include // _mm_prefetch +#elif HWY_ARCH_ARM_A64 +#include #endif +#endif // HWY_DISABLE_CACHE_CONTROL namespace hwy { @@ -76,15 +80,16 @@ HWY_INLINE HWY_ATTR_CACHE void FlushStream() { // subsequent actual loads. template HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) { -#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + (void)p; +#ifndef HWY_DISABLE_CACHE_CONTROL +#if HWY_ARCH_X86 _mm_prefetch(reinterpret_cast(p), _MM_HINT_T0); #elif HWY_COMPILER_GCC // includes clang // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not // desirable, so use the default 3 (keep in caches). __builtin_prefetch(p, /*write=*/0, /*hint=*/3); -#else - (void)p; #endif +#endif // HWY_DISABLE_CACHE_CONTROL } // Invalidates and flushes the cache line containing "p", if possible. @@ -96,11 +101,24 @@ HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) { #endif } -// When called inside a spin-loop, may reduce power consumption. +// Hints that we are inside a spin loop and potentially reduces power +// consumption and coherency traffic. For example, x86 avoids multiple +// outstanding load requests, which reduces the memory order violation penalty +// when exiting the loop. 
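As a usage sketch of the expanded `Pause()` hint documented above (the flag and the wrapper function are illustrative, not part of the library):

```cpp
#include <atomic>

#include "hwy/cache_control.h"

// Busy-waits until another thread publishes `ready`, hinting the CPU that
// this is a spin loop so it can throttle and reduce coherency traffic.
void SpinUntilReady(const std::atomic<bool>& ready) {
  while (!ready.load(std::memory_order_acquire)) {
    hwy::Pause();
  }
}
```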
HWY_INLINE HWY_ATTR_CACHE void Pause() { -#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) +#ifndef HWY_DISABLE_CACHE_CONTROL +#if HWY_ARCH_X86 _mm_pause(); +#elif HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG + // This is documented in ACLE and the YIELD instruction is also available in + // Armv7, but the intrinsic is broken for Armv7 clang, hence A64 only. + __yield(); +#elif HWY_ARCH_ARM && HWY_COMPILER_GCC // includes clang + __asm__ volatile("yield" ::: "memory"); +#elif HWY_ARCH_PPC && HWY_COMPILER_GCC // includes clang + __asm__ volatile("or 27,27,27" ::: "memory"); #endif +#endif // HWY_DISABLE_CACHE_CONTROL } } // namespace hwy diff --git a/r/src/vendor/highway/hwy/contrib/algo/copy-inl.h b/r/src/vendor/highway/hwy/contrib/algo/copy-inl.h index 22f4252c..9945132f 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/copy-inl.h +++ b/r/src/vendor/highway/hwy/contrib/algo/copy-inl.h @@ -14,13 +14,17 @@ // limitations under the License. // Per-target include guard -#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == defined(HWY_TARGET_TOGGLE) +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT #ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ #undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ #else #define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ #endif +#include +#include + #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); @@ -40,8 +44,10 @@ void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) { const Vec v = Set(d, value); size_t idx = 0; - for (; idx + N <= count; idx += N) { - StoreU(v, d, to + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(v, d, to + idx); + } } // `count` was a multiple of the vector length `N`: already done. @@ -58,9 +64,11 @@ void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) { const size_t N = Lanes(d); size_t idx = 0; - for (; idx + N <= count; idx += N) { - const Vec v = LoadU(d, from + idx); - StoreU(v, d, to + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + StoreU(v, d, to + idx); + } } // `count` was a multiple of the vector length `N`: already done. @@ -89,9 +97,11 @@ T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to, const size_t N = Lanes(d); size_t idx = 0; - for (; idx + N <= count; idx += N) { - const Vec v = LoadU(d, from + idx); - to += CompressBlendedStore(v, func(d, v), d, to); + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + to += CompressBlendedStore(v, func(d, v), d, to); + } } // `count` was a multiple of the vector length `N`: already done. diff --git a/r/src/vendor/highway/hwy/contrib/algo/copy_test.cc b/r/src/vendor/highway/hwy/contrib/algo/copy_test.cc index c74f6e9b..054a01ce 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/copy_test.cc +++ b/r/src/vendor/highway/hwy/contrib/algo/copy_test.cc @@ -13,6 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "hwy/aligned_allocator.h" // clang-format off @@ -35,11 +37,12 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +namespace { // Returns random integer in [0, 128), which fits in any lane type. 
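The Fill/Copy/CopyIf changes above guard the main vector loop with `count >= N` so the unsigned bound `count - N` cannot wrap for short inputs, with the partial tail handled separately as before. A minimal static-dispatch sketch of calling these helpers (function and parameter names are illustrative; real callers typically pair the -inl header with foreach_target.h for dynamic dispatch):

```cpp
#include "hwy/highway.h"
#include "hwy/contrib/algo/copy-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

// Sets dst[0, count) to `value`, then copies it into out[0, count); both
// helpers handle any partial final vector internally.
void FillThenCopy(float* HWY_RESTRICT dst, float value, size_t count,
                  float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  hn::Fill(d, value, count, dst);
  hn::Copy(d, dst, count, out);
}
```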
template T Random7Bit(RandomState& rng) { - return static_cast(Random32(&rng) & 127); + return ConvertScalarTo(Random32(&rng) & 127); } // In C++14, we can instead define these as generic lambdas next to where they @@ -92,9 +95,9 @@ struct TestFill { } T* actual = pb.get() + misalign_b; - actual[count] = T{0}; // sentinel + actual[count] = ConvertScalarTo(0); // sentinel Fill(d, value, count, actual); - HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + HWY_ASSERT_EQ(ConvertScalarTo(0), actual[count]); // no write past end const auto info = hwy::detail::MakeTypeInfo(); const char* target_name = hwy::TargetName(HWY_TARGET); @@ -187,18 +190,21 @@ void TestAllCopyIf() { ForUI163264(ForPartialVectors>()); } +} // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace hwy { +namespace { HWY_BEFORE_TEST(CopyTest); HWY_EXPORT_AND_TEST_P(CopyTest, TestAllFill); HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopy); HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopyIf); +HWY_AFTER_TEST(); +} // namespace } // namespace hwy - -#endif +HWY_TEST_MAIN(); +#endif // HWY_ONCE diff --git a/r/src/vendor/highway/hwy/contrib/algo/find-inl.h b/r/src/vendor/highway/hwy/contrib/algo/find-inl.h index c1e5a843..dc0a8cac 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/find-inl.h +++ b/r/src/vendor/highway/hwy/contrib/algo/find-inl.h @@ -14,7 +14,8 @@ // limitations under the License. // Per-target include guard -#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == defined(HWY_TARGET_TOGGLE) +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT #ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ #undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ #else @@ -35,9 +36,11 @@ size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { const Vec broadcasted = Set(d, value); size_t i = 0; - for (; i + N <= count; i += N) { - const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); - if (pos >= 0) return i + static_cast(pos); + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } } if (i != count) { @@ -72,9 +75,11 @@ size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { const size_t N = Lanes(d); size_t i = 0; - for (; i + N <= count; i += N) { - const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); - if (pos >= 0) return i + static_cast(pos); + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } } if (i != count) { diff --git a/r/src/vendor/highway/hwy/contrib/algo/find_test.cc b/r/src/vendor/highway/hwy/contrib/algo/find_test.cc index 3c7c1363..8593b60e 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/find_test.cc +++ b/r/src/vendor/highway/hwy/contrib/algo/find_test.cc @@ -42,6 +42,7 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +namespace { // Returns random number in [-8, 8] - we use knowledge of the range to Find() // values we know are not present. 
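As a usage sketch of the `Find()` helper changed above (static dispatch only; the function name and element type are illustrative):

```cpp
#include <stdint.h>

#include "hwy/highway.h"
#include "hwy/contrib/algo/find-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

// Returns the index of the first occurrence of `needle`, or `count` if it is
// absent (the test below relies on this for out-of-range values).
size_t IndexOf(const int32_t* HWY_RESTRICT in, size_t count, int32_t needle) {
  const hn::ScalableTag<int32_t> d;
  return hn::Find(d, needle, in, count);
}
```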
@@ -53,7 +54,7 @@ T Random(RandomState& rng) { if (!hwy::IsSigned() && val < 0.0) { val = -val; } - return static_cast(val); + return ConvertScalarTo(val); } // In C++14, we can instead define these as generic lambdas next to where they @@ -65,7 +66,7 @@ class GreaterThan { GreaterThan(int val) : val_(val) {} template Mask operator()(D d, V v) const { - return Gt(v, Set(d, static_cast>(val_))); + return Gt(v, Set(d, ConvertScalarTo>(val_))); } private: @@ -121,15 +122,15 @@ struct TestFind { if (!IsEqual(in[pos], in[actual])) { fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n", hwy::TypeName(T(), Lanes(d)).c_str(), static_cast(count), - static_cast(in[actual]), static_cast(actual), - static_cast(in[pos])); + ConvertScalarTo(in[actual]), static_cast(actual), + ConvertScalarTo(in[pos])); HWY_ASSERT(false); } for (size_t i = 0; i < actual; ++i) { if (IsEqual(in[i], in[pos])) { fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n", hwy::TypeName(T(), Lanes(d)).c_str(), static_cast(count), - static_cast(in[i]), static_cast(i), + ConvertScalarTo(in[i]), static_cast(i), static_cast(actual)); HWY_ASSERT(false); } @@ -137,8 +138,8 @@ struct TestFind { } // Also search for values we know not to be present (out of range) - HWY_ASSERT_EQ(count, Find(d, T{9}, in, count)); - HWY_ASSERT_EQ(count, Find(d, static_cast(-9), in, count)); + HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo(9), in, count)); + HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo(-9), in, count)); } }; @@ -158,8 +159,8 @@ struct TestFindIf { T* in = storage.get() + misalign; for (size_t i = 0; i < count; ++i) { in[i] = Random(rng); - HWY_ASSERT(static_cast(in[i]) <= 8); - HWY_ASSERT(!hwy::IsSigned() || static_cast(in[i]) >= -8); + HWY_ASSERT(ConvertScalarTo(in[i]) <= 8); + HWY_ASSERT(!hwy::IsSigned() || ConvertScalarTo(in[i]) >= -8); } bool found_any = false; @@ -173,7 +174,7 @@ struct TestFindIf { for (int val = min_val; val <= 9; ++val) { #if HWY_GENERIC_LAMBDA const auto greater = [val](const auto d, const auto v) HWY_ATTR { - return Gt(v, Set(d, static_cast(val))); + return Gt(v, Set(d, ConvertScalarTo(val))); }; #else const GreaterThan greater(val); @@ -183,7 +184,7 @@ struct TestFindIf { not_found_any |= actual == count; const auto pos = std::find_if( - in, in + count, [val](T x) { return x > static_cast(val); }); + in, in + count, [val](T x) { return x > ConvertScalarTo(val); }); // Convert returned iterator to index. const size_t expected = static_cast(pos - in); if (expected != actual) { @@ -200,7 +201,7 @@ struct TestFindIf { HWY_ASSERT(not_found_any); // We'll find something unless the input is empty or {0} - because 0 > i // is false for all i=[0,9]. 
- if (count != 0 && in[0] != T{0}) { + if (count != 0 && in[0] != ConvertScalarTo(0)) { HWY_ASSERT(found_any); } } @@ -210,17 +211,20 @@ void TestAllFindIf() { ForAllTypes(ForPartialVectors>()); } +} // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace hwy { +namespace { HWY_BEFORE_TEST(FindTest); HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind); HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf); +HWY_AFTER_TEST(); +} // namespace } // namespace hwy - -#endif +HWY_TEST_MAIN(); +#endif // HWY_ONCE diff --git a/r/src/vendor/highway/hwy/contrib/algo/transform-inl.h b/r/src/vendor/highway/hwy/contrib/algo/transform-inl.h index 3e830acb..f48476d1 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/transform-inl.h +++ b/r/src/vendor/highway/hwy/contrib/algo/transform-inl.h @@ -22,6 +22,8 @@ #define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ #endif +#include + #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); @@ -35,13 +37,12 @@ namespace HWY_NAMESPACE { // would be more verbose than such a loop. // // Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a -// generic lambda if using C++14. Due to apparent limitations of Clang on -// Windows, it is currently necessary to add HWY_ATTR before the opening { of -// the lambda to avoid errors about "always_inline function .. requires target". +// generic lambda if using C++14. The d argument is the same as was passed to +// the Generate etc. functions. Due to apparent limitations of Clang, it is +// currently necessary to add HWY_ATTR before the opening { of the lambda to +// avoid errors about "always_inline function .. requires target". // -// If HWY_MEM_OPS_MIGHT_FAULT, we use scalar code instead of masking. Otherwise, -// we used `MaskedLoad` and `BlendedStore` to read/write the final partial -// vector. +// We do not check HWY_MEM_OPS_MIGHT_FAULT because LoadN/StoreN do not fault. // Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`, // where `index_vec` is `Vec>`. On the first call to `func`, @@ -56,27 +57,43 @@ void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) { size_t idx = 0; Vec vidx = Iota(du, 0); - for (; idx + N <= count; idx += N) { - StoreU(func(d, vidx), d, out + idx); - vidx = Add(vidx, Set(du, static_cast(N))); + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(func(d, vidx), d, out + idx); + vidx = Add(vidx, Set(du, static_cast(N))); + } } // `count` was a multiple of the vector length `N`: already done. if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. - const CappedTag d1; - const RebindToUnsigned du1; - for (; idx < count; ++idx) { - StoreU(func(d1, Set(du1, static_cast(idx))), d1, out + idx); + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + StoreN(func(d, vidx), d, out + idx, remaining); +} + +// Calls `func(d, v)` for each input vector; out of bound lanes with index i >= +// `count` are instead taken from `no[i % Lanes(d)]`. +template > +void Foreach(D d, const T* HWY_RESTRICT in, const size_t count, const Vec no, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, in + idx); + func(d, v); + } } -#else + + // `count` was a multiple of the vector length `N`: already done. 
+ if (HWY_UNLIKELY(idx == count)) return; + const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - BlendedStore(func(d, vidx), mask, d, out + idx); -#endif + const Vec v = LoadNOr(no, d, in + idx, remaining); + func(d, v); } // Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying @@ -86,29 +103,20 @@ void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) { const size_t N = Lanes(d); size_t idx = 0; - for (; idx + N <= count; idx += N) { - const Vec v = LoadU(d, inout + idx); - StoreU(func(d, v), d, inout + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + StoreU(func(d, v), d, inout + idx); + } } // `count` was a multiple of the vector length `N`: already done. if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. - const CappedTag d1; - for (; idx < count; ++idx) { - using V1 = Vec; - const V1 v = LoadU(d1, inout + idx); - StoreU(func(d1, v), d1, inout + idx); - } -#else const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - const Vec v = MaskedLoad(mask, d, inout + idx); - BlendedStore(func(d, v), mask, d, inout + idx); -#endif + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(func(d, v), d, inout + idx, remaining); } // Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage: @@ -119,32 +127,22 @@ void Transform1(D d, T* HWY_RESTRICT inout, size_t count, const size_t N = Lanes(d); size_t idx = 0; - for (; idx + N <= count; idx += N) { - const Vec v = LoadU(d, inout + idx); - const Vec v1 = LoadU(d, in1 + idx); - StoreU(func(d, v, v1), d, inout + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + StoreU(func(d, v, v1), d, inout + idx); + } } // `count` was a multiple of the vector length `N`: already done. if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. - const CappedTag d1; - for (; idx < count; ++idx) { - using V1 = Vec; - const V1 v = LoadU(d1, inout + idx); - const V1 v1 = LoadU(d1, in1 + idx); - StoreU(func(d1, v, v1), d1, inout + idx); - } -#else const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - const Vec v = MaskedLoad(mask, d, inout + idx); - const Vec v1 = MaskedLoad(mask, d, in1 + idx); - BlendedStore(func(d, v, v1), mask, d, inout + idx); -#endif + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + StoreN(func(d, v, v1), d, inout + idx, remaining); } // Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example @@ -156,35 +154,24 @@ void Transform2(D d, T* HWY_RESTRICT inout, size_t count, const size_t N = Lanes(d); size_t idx = 0; - for (; idx + N <= count; idx += N) { - const Vec v = LoadU(d, inout + idx); - const Vec v1 = LoadU(d, in1 + idx); - const Vec v2 = LoadU(d, in2 + idx); - StoreU(func(d, v, v1, v2), d, inout + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + const Vec v2 = LoadU(d, in2 + idx); + StoreU(func(d, v, v1, v2), d, inout + idx); + } } // `count` was a multiple of the vector length `N`: already done. 
if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. - const CappedTag d1; - for (; idx < count; ++idx) { - using V1 = Vec; - const V1 v = LoadU(d1, inout + idx); - const V1 v1 = LoadU(d1, in1 + idx); - const V1 v2 = LoadU(d1, in2 + idx); - StoreU(func(d1, v, v1, v2), d1, inout + idx); - } -#else const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - const Vec v = MaskedLoad(mask, d, inout + idx); - const Vec v1 = MaskedLoad(mask, d, in1 + idx); - const Vec v2 = MaskedLoad(mask, d, in2 + idx); - BlendedStore(func(d, v, v1, v2), mask, d, inout + idx); -#endif + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + const Vec v2 = LoadN(d, in2 + idx, remaining); + StoreN(func(d, v, v1, v2), d, inout + idx, remaining); } template > @@ -194,31 +181,20 @@ void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) { const Vec new_v = Set(d, new_t); size_t idx = 0; - for (; idx + N <= count; idx += N) { - Vec v = LoadU(d, inout + idx); - StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + } } // `count` was a multiple of the vector length `N`: already done. if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. - const CappedTag d1; - const Vec old_v1 = Set(d1, old_t); - const Vec new_v1 = Set(d1, new_t); - for (; idx < count; ++idx) { - using V1 = Vec; - const V1 v1 = LoadU(d1, inout + idx); - StoreU(IfThenElse(Eq(v1, old_v1), new_v1, v1), d1, inout + idx); - } -#else const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - const Vec v = MaskedLoad(mask, d, inout + idx); - BlendedStore(IfThenElse(Eq(v, old_v), new_v, v), mask, d, inout + idx); -#endif + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx, remaining); } template > @@ -228,30 +204,20 @@ void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t, const Vec new_v = Set(d, new_t); size_t idx = 0; - for (; idx + N <= count; idx += N) { - Vec v = LoadU(d, inout + idx); - StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + } } // `count` was a multiple of the vector length `N`: already done. if (HWY_UNLIKELY(idx == count)) return; -#if HWY_MEM_OPS_MIGHT_FAULT - // Proceed one by one. 
- const CappedTag d1; - const Vec new_v1 = Set(d1, new_t); - for (; idx < count; ++idx) { - using V1 = Vec; - const V1 v = LoadU(d1, inout + idx); - StoreU(IfThenElse(func(d1, v), new_v1, v), d1, inout + idx); - } -#else const size_t remaining = count - idx; HWY_DASSERT(0 != remaining && remaining < N); - const Mask mask = FirstN(d, remaining); - const Vec v = MaskedLoad(mask, d, inout + idx); - BlendedStore(IfThenElse(func(d, v), new_v, v), mask, d, inout + idx); -#endif + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(func(d, v), new_v, v), d, inout + idx, remaining); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/r/src/vendor/highway/hwy/contrib/algo/transform_test.cc b/r/src/vendor/highway/hwy/contrib/algo/transform_test.cc index 9ac87ea6..fc4fd16a 100644 --- a/r/src/vendor/highway/hwy/contrib/algo/transform_test.cc +++ b/r/src/vendor/highway/hwy/contrib/algo/transform_test.cc @@ -18,6 +18,7 @@ #include #include "hwy/aligned_allocator.h" +#include "hwy/base.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -39,6 +40,7 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +namespace { constexpr double kAlpha = 1.5; // arbitrary scalar @@ -49,22 +51,23 @@ T Random(RandomState& rng) { const int32_t bits = static_cast(Random32(&rng)) & 1023; const double val = (bits - 512) / 64.0; // Clamp negative to zero for unsigned types. - return static_cast( - HWY_MAX(static_cast(hwy::LowestValue()), val)); + return ConvertScalarTo( + HWY_MAX(ConvertScalarTo(hwy::LowestValue()), val)); } // SCAL, AXPY names are from BLAS. template HWY_NOINLINE void SimpleSCAL(const T* x, T* out, size_t count) { for (size_t i = 0; i < count; ++i) { - out[i] = static_cast(kAlpha * x[i]); + out[i] = ConvertScalarTo(ConvertScalarTo(kAlpha) * x[i]); } } template HWY_NOINLINE void SimpleAXPY(const T* x, const T* y, T* out, size_t count) { for (size_t i = 0; i < count; ++i) { - out[i] = static_cast(kAlpha * x[i] + y[i]); + out[i] = ConvertScalarTo( + ConvertScalarTo(ConvertScalarTo(kAlpha) * x[i]) + y[i]); } } @@ -72,7 +75,7 @@ template HWY_NOINLINE void SimpleFMA4(const T* x, const T* y, const T* z, T* out, size_t count) { for (size_t i = 0; i < count; ++i) { - out[i] = static_cast(x[i] * y[i] + z[i]); + out[i] = ConvertScalarTo(x[i] * y[i] + z[i]); } } @@ -92,7 +95,7 @@ struct SCAL { template Vec operator()(D d, V v) const { using T = TFromD; - return Mul(Set(d, static_cast(kAlpha)), v); + return Mul(Set(d, ConvertScalarTo(kAlpha)), v); } }; @@ -100,7 +103,7 @@ struct AXPY { template Vec operator()(D d, V v, V v1) const { using T = TFromD; - return MulAdd(Set(d, static_cast(kAlpha)), v, v1); + return MulAdd(Set(d, ConvertScalarTo(kAlpha)), v, v1); } }; @@ -133,6 +136,24 @@ struct ForeachCountAndMisalign { } }; +// Fills an array with random values, placing a given sentinel value both before +// (when misalignment space is available) and after. Requires an allocation of +// at least count + misalign + 1 elements. 
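As a sketch of calling `Transform()` from transform-inl.h above with a generic lambda (static dispatch; the function name is illustrative, and the example assumes a compiler that accepts HWY_ATTR on generic lambdas, which the tests guard with HWY_GENERIC_LAMBDA):

```cpp
#include "hwy/highway.h"
#include "hwy/contrib/algo/transform-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

// Multiplies x[0, count) by `scale` in place; the partial final vector is
// handled inside Transform via LoadN/StoreN, so no padding is required.
void ScaleInPlace(float* HWY_RESTRICT x, size_t count, float scale) {
  const hn::ScalableTag<float> d;
  hn::Transform(d, x, count, [scale](const auto d, const auto v) HWY_ATTR {
    return hn::Mul(hn::Set(d, scale), v);
  });
}
```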
+template +T* FillRandom(AlignedFreeUniquePtr& pa, size_t count, size_t misalign, + T sentinel, RandomState& rng) { + for (size_t i = 0; i < misalign; ++i) { + pa[i] = sentinel; + } + + T* a = pa.get() + misalign; + for (size_t i = 0; i < count; ++i) { + a[i] = Random(rng); + } + a[count] = sentinel; + return a; +} + // Output-only, no loads struct TestGenerate { template @@ -146,7 +167,7 @@ struct TestGenerate { T* actual = pa.get() + misalign_a; for (size_t i = 0; i < count; ++i) { - expected[i] = static_cast(2 * i); + expected[i] = ConvertScalarTo(2 * i); } // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that @@ -157,9 +178,9 @@ struct TestGenerate { #else const Gen2 gen2; #endif - actual[count] = T{0}; // sentinel + actual[count] = ConvertScalarTo(0); // sentinel Generate(d, actual, count, gen2); - HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + HWY_ASSERT_EQ(ConvertScalarTo(0), actual[count]); // no write past end const auto info = hwy::detail::MakeTypeInfo(); const char* target_name = hwy::TargetName(HWY_TARGET); @@ -168,6 +189,42 @@ struct TestGenerate { } }; +// Input-only, no stores +struct TestForeach { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& /*rng*/) { + if (misalign_b != 0) return; + using T = TFromD; + AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + count + 1); + HWY_ASSERT(pa); + + T* actual = pa.get() + misalign_a; + T max = hwy::LowestValue(); + for (size_t i = 0; i < count; ++i) { + actual[i] = hwy::ConvertScalarTo(i <= count / 2 ? 2 * i : i); + max = HWY_MAX(max, actual[i]); + } + + // Place sentinel values in the misalignment area and at the input's end. + for (size_t i = 0; i < misalign_a; ++i) { + pa[i] = ConvertScalarTo(2 * count); + } + actual[count] = ConvertScalarTo(2 * count); + + const Vec vmin = Set(d, hwy::LowestValue()); + // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that + // the attribute also applies to lambdas? If so, remove HWY_ATTR. + Vec vmax = vmin; + const auto func = [&vmax](const D, const Vec v) + HWY_ATTR { vmax = Max(vmax, v); }; + Foreach(d, actual, count, vmin, func); + + const char* target_name = hwy::TargetName(HWY_TARGET); + AssertEqual(max, ReduceMax(d, vmax), target_name, __FILE__, __LINE__); + } +}; + // Zero extra input arrays struct TestTransform { template @@ -177,22 +234,19 @@ struct TestTransform { using T = TFromD; // Prevents error if size to allocate is zero. AlignedFreeUniquePtr pa = - AllocateAligned(HWY_MAX(1, misalign_a + count)); + AllocateAligned(HWY_MAX(1, misalign_a + count + 1)); AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); HWY_ASSERT(pa && expected); - T* a = pa.get() + misalign_a; - for (size_t i = 0; i < count; ++i) { - a[i] = Random(rng); - } - + const T sentinel = ConvertScalarTo(-42); + T* a = FillRandom(pa, count, misalign_a, sentinel, rng); SimpleSCAL(a, expected.get(), count); // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that // the attribute also applies to lambdas? If so, remove HWY_ATTR. #if HWY_GENERIC_LAMBDA const auto scal = [](const auto d, const auto v) HWY_ATTR { - return Mul(Set(d, static_cast(kAlpha)), v); + return Mul(Set(d, ConvertScalarTo(kAlpha)), v); }; #else const SCAL scal; @@ -203,6 +257,12 @@ struct TestTransform { const char* target_name = hwy::TargetName(HWY_TARGET); hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, __FILE__, __LINE__); + + // Ensure no out-of-bound writes. 
+ for (size_t i = 0; i < misalign_a; ++i) { + HWY_ASSERT_EQ(sentinel, pa[i]); + } + HWY_ASSERT_EQ(sentinel, a[count]); } }; @@ -214,15 +274,16 @@ struct TestTransform1 { using T = TFromD; // Prevents error if size to allocate is zero. AlignedFreeUniquePtr pa = - AllocateAligned(HWY_MAX(1, misalign_a + count)); + AllocateAligned(HWY_MAX(1, misalign_a + count + 1)); AlignedFreeUniquePtr pb = AllocateAligned(HWY_MAX(1, misalign_b + count)); AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); HWY_ASSERT(pa && pb && expected); - T* a = pa.get() + misalign_a; + + const T sentinel = ConvertScalarTo(-42); + T* a = FillRandom(pa, count, misalign_a, sentinel, rng); T* b = pb.get() + misalign_b; for (size_t i = 0; i < count; ++i) { - a[i] = Random(rng); b[i] = Random(rng); } @@ -230,17 +291,20 @@ struct TestTransform1 { #if HWY_GENERIC_LAMBDA const auto axpy = [](const auto d, const auto v, const auto v1) HWY_ATTR { - return MulAdd(Set(d, static_cast(kAlpha)), v, v1); + return MulAdd(Set(d, ConvertScalarTo(kAlpha)), v, v1); }; #else const AXPY axpy; #endif Transform1(d, a, count, b, axpy); - const auto info = hwy::detail::MakeTypeInfo(); - const char* target_name = hwy::TargetName(HWY_TARGET); - hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, - __FILE__, __LINE__); + AssertArraySimilar(expected.get(), a, count, hwy::TargetName(HWY_TARGET), + __FILE__, __LINE__); + // Ensure no out-of-bound writes. + for (size_t i = 0; i < misalign_a; ++i) { + HWY_ASSERT_EQ(sentinel, pa[i]); + } + HWY_ASSERT_EQ(sentinel, a[count]); } }; @@ -252,18 +316,19 @@ struct TestTransform2 { using T = TFromD; // Prevents error if size to allocate is zero. AlignedFreeUniquePtr pa = - AllocateAligned(HWY_MAX(1, misalign_a + count)); + AllocateAligned(HWY_MAX(1, misalign_a + count + 1)); AlignedFreeUniquePtr pb = AllocateAligned(HWY_MAX(1, misalign_b + count)); AlignedFreeUniquePtr pc = AllocateAligned(HWY_MAX(1, misalign_a + count)); AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); HWY_ASSERT(pa && pb && pc && expected); - T* a = pa.get() + misalign_a; + + const T sentinel = ConvertScalarTo(-42); + T* a = FillRandom(pa, count, misalign_a, sentinel, rng); T* b = pb.get() + misalign_b; T* c = pc.get() + misalign_a; for (size_t i = 0; i < count; ++i) { - a[i] = Random(rng); b[i] = Random(rng); c[i] = Random(rng); } @@ -278,10 +343,13 @@ struct TestTransform2 { #endif Transform2(d, a, count, b, c, fma4); - const auto info = hwy::detail::MakeTypeInfo(); - const char* target_name = hwy::TargetName(HWY_TARGET); - hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, - __FILE__, __LINE__); + AssertArraySimilar(expected.get(), a, count, hwy::TargetName(HWY_TARGET), + __FILE__, __LINE__); + // Ensure no out-of-bound writes. 
+ for (size_t i = 0; i < misalign_a; ++i) { + HWY_ASSERT_EQ(sentinel, pa[i]); + } + HWY_ASSERT_EQ(sentinel, a[count]); } }; @@ -306,15 +374,13 @@ struct TestReplace { if (misalign_b != 0) return; if (count == 0) return; using T = TFromD; - AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + count); + AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + count + 1); AlignedFreeUniquePtr pb = AllocateAligned(count); AlignedFreeUniquePtr expected = AllocateAligned(count); HWY_ASSERT(pa && pb && expected); - T* a = pa.get() + misalign_a; - for (size_t i = 0; i < count; ++i) { - a[i] = Random(rng); - } + const T sentinel = ConvertScalarTo(-42); + T* a = FillRandom(pa, count, misalign_a, sentinel, rng); std::vector positions(AdjustedReps(count)); for (size_t& pos : positions) { @@ -333,9 +399,19 @@ struct TestReplace { Replace(d, a, count, new_t, old_t); HWY_ASSERT_ARRAY_EQ(expected.get(), a, count); + // Ensure no out-of-bound writes. + for (size_t i = 0; i < misalign_a; ++i) { + HWY_ASSERT_EQ(sentinel, pa[i]); + } + HWY_ASSERT_EQ(sentinel, a[count]); ReplaceIf(d, pb.get(), count, new_t, IfEq(old_t)); HWY_ASSERT_ARRAY_EQ(expected.get(), pb.get(), count); + // Ensure no out-of-bound writes. + for (size_t i = 0; i < misalign_a; ++i) { + HWY_ASSERT_EQ(sentinel, pa[i]); + } + HWY_ASSERT_EQ(sentinel, a[count]); } } }; @@ -345,6 +421,10 @@ void TestAllGenerate() { ForIntegerTypes(ForPartialVectors>()); } +void TestAllForeach() { + ForAllTypes(ForPartialVectors>()); +} + void TestAllTransform() { ForFloatTypes(ForPartialVectors>()); } @@ -361,20 +441,24 @@ void TestAllReplace() { ForFloatTypes(ForPartialVectors>()); } +} // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace hwy { +namespace { HWY_BEFORE_TEST(TransformTest); HWY_EXPORT_AND_TEST_P(TransformTest, TestAllGenerate); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllForeach); HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform); HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform1); HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform2); HWY_EXPORT_AND_TEST_P(TransformTest, TestAllReplace); +HWY_AFTER_TEST(); +} // namespace } // namespace hwy - -#endif +HWY_TEST_MAIN(); +#endif // HWY_ONCE diff --git a/r/src/vendor/highway/hwy/contrib/math/math-inl.h b/r/src/vendor/highway/hwy/contrib/math/math-inl.h index d701c5e9..d7416845 100644 --- a/r/src/vendor/highway/hwy/contrib/math/math-inl.h +++ b/r/src/vendor/highway/hwy/contrib/math/math-inl.h @@ -22,6 +22,9 @@ #define HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ #endif +#include +#include + #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); @@ -37,7 +40,7 @@ namespace HWY_NAMESPACE { * @return arc cosine of 'x' */ template -HWY_INLINE V Acos(const D d, V x); +HWY_INLINE V Acos(D d, V x); template HWY_NOINLINE V CallAcos(const D d, VecArg x) { return Acos(d, x); @@ -52,7 +55,7 @@ HWY_NOINLINE V CallAcos(const D d, VecArg x) { * @return hyperbolic arc cosine of 'x' */ template -HWY_INLINE V Acosh(const D d, V x); +HWY_INLINE V Acosh(D d, V x); template HWY_NOINLINE V CallAcosh(const D d, VecArg x) { return Acosh(d, x); @@ -67,7 +70,7 @@ HWY_NOINLINE V CallAcosh(const D d, VecArg x) { * @return arc sine of 'x' */ template -HWY_INLINE V Asin(const D d, V x); +HWY_INLINE V Asin(D d, V x); template HWY_NOINLINE V CallAsin(const D d, VecArg x) { return Asin(d, x); @@ -82,7 +85,7 @@ HWY_NOINLINE V CallAsin(const D d, VecArg x) { * @return hyperbolic arc sine of 'x' */ template -HWY_INLINE V 
Asinh(const D d, V x); +HWY_INLINE V Asinh(D d, V x); template HWY_NOINLINE V CallAsinh(const D d, VecArg x) { return Asinh(d, x); @@ -97,7 +100,7 @@ HWY_NOINLINE V CallAsinh(const D d, VecArg x) { * @return arc tangent of 'x' */ template -HWY_INLINE V Atan(const D d, V x); +HWY_INLINE V Atan(D d, V x); template HWY_NOINLINE V CallAtan(const D d, VecArg x) { return Atan(d, x); @@ -112,7 +115,7 @@ HWY_NOINLINE V CallAtan(const D d, VecArg x) { * @return hyperbolic arc tangent of 'x' */ template -HWY_INLINE V Atanh(const D d, V x); +HWY_INLINE V Atanh(D d, V x); template HWY_NOINLINE V CallAtanh(const D d, VecArg x) { return Atanh(d, x); @@ -175,7 +178,7 @@ HWY_NOINLINE V CallAtan2(const D d, VecArg y, VecArg x) { * @return cosine of 'x' */ template -HWY_INLINE V Cos(const D d, V x); +HWY_INLINE V Cos(D d, V x); template HWY_NOINLINE V CallCos(const D d, VecArg x) { return Cos(d, x); @@ -190,12 +193,27 @@ HWY_NOINLINE V CallCos(const D d, VecArg x) { * @return e^x */ template -HWY_INLINE V Exp(const D d, V x); +HWY_INLINE V Exp(D d, V x); template HWY_NOINLINE V CallExp(const D d, VecArg x) { return Exp(d, x); } +/** + * Highway SIMD version of std::exp2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[-FLT_MAX, +128], float64[-DBL_MAX, +1024] + * @return 2^x + */ +template +HWY_INLINE V Exp2(D d, V x); +template +HWY_NOINLINE V CallExp2(const D d, VecArg x) { + return Exp2(d, x); +} + /** * Highway SIMD version of std::expm1(x). * @@ -205,7 +223,7 @@ HWY_NOINLINE V CallExp(const D d, VecArg x) { * @return e^x - 1 */ template -HWY_INLINE V Expm1(const D d, V x); +HWY_INLINE V Expm1(D d, V x); template HWY_NOINLINE V CallExpm1(const D d, VecArg x) { return Expm1(d, x); @@ -220,7 +238,7 @@ HWY_NOINLINE V CallExpm1(const D d, VecArg x) { * @return natural logarithm of 'x' */ template -HWY_INLINE V Log(const D d, V x); +HWY_INLINE V Log(D d, V x); template HWY_NOINLINE V CallLog(const D d, VecArg x) { return Log(d, x); @@ -235,7 +253,7 @@ HWY_NOINLINE V CallLog(const D d, VecArg x) { * @return base 10 logarithm of 'x' */ template -HWY_INLINE V Log10(const D d, V x); +HWY_INLINE V Log10(D d, V x); template HWY_NOINLINE V CallLog10(const D d, VecArg x) { return Log10(d, x); @@ -250,7 +268,7 @@ HWY_NOINLINE V CallLog10(const D d, VecArg x) { * @return log(1 + x) */ template -HWY_INLINE V Log1p(const D d, V x); +HWY_INLINE V Log1p(D d, V x); template HWY_NOINLINE V CallLog1p(const D d, VecArg x) { return Log1p(d, x); @@ -265,7 +283,7 @@ HWY_NOINLINE V CallLog1p(const D d, VecArg x) { * @return base 2 logarithm of 'x' */ template -HWY_INLINE V Log2(const D d, V x); +HWY_INLINE V Log2(D d, V x); template HWY_NOINLINE V CallLog2(const D d, VecArg x) { return Log2(d, x); @@ -280,7 +298,7 @@ HWY_NOINLINE V CallLog2(const D d, VecArg x) { * @return sine of 'x' */ template -HWY_INLINE V Sin(const D d, V x); +HWY_INLINE V Sin(D d, V x); template HWY_NOINLINE V CallSin(const D d, VecArg x) { return Sin(d, x); @@ -295,7 +313,7 @@ HWY_NOINLINE V CallSin(const D d, VecArg x) { * @return hyperbolic sine of 'x' */ template -HWY_INLINE V Sinh(const D d, V x); +HWY_INLINE V Sinh(D d, V x); template HWY_NOINLINE V CallSinh(const D d, VecArg x) { return Sinh(d, x); @@ -310,7 +328,7 @@ HWY_NOINLINE V CallSinh(const D d, VecArg x) { * @return hyperbolic tangent of 'x' */ template -HWY_INLINE V Tanh(const D d, V x); +HWY_INLINE V Tanh(D d, V x); template HWY_NOINLINE V CallTanh(const D d, VecArg x) { return Tanh(d, x); @@ -327,12 +345,27 @@ HWY_NOINLINE V CallTanh(const D d, 
VecArg x) { * @return sine and cosine of 'x' */ template -HWY_INLINE void SinCos(const D d, V x, V& s, V& c); +HWY_INLINE void SinCos(D d, V x, V& s, V& c); template -HWY_NOINLINE V CallSinCos(const D d, VecArg x, VecArg& s, VecArg& c) { +HWY_NOINLINE void CallSinCos(const D d, VecArg x, V& s, V& c) { SinCos(d, x, s, c); } +/** + * Highway SIMD version of Hypot + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hypotenuse of a and b + */ +template +HWY_INLINE V Hypot(D d, V a, V b); +template +HWY_NOINLINE V CallHypot(const D d, VecArg a, VecArg b) { + return Hypot(d, a, b); +} + //////////////////////////////////////////////////////////////////////////////// // Implementation //////////////////////////////////////////////////////////////////////////////// @@ -790,6 +823,12 @@ struct ExpImpl { return ConvertTo(Rebind(), x); } + // Rounds float to nearest int32_t + template + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return NearestInt(x); + } + template HWY_INLINE V ExpPoly(D d, V x) { const auto k0 = Set(d, +0.5f); @@ -829,6 +868,13 @@ struct ExpImpl { x = MulAdd(qf, kLn2Part1f, x); return x; } + + template + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, ConvertTo(d, q)); + return MulAdd(x_frac, Set(d, 0.193147182464599609375f), + Mul(x_frac, Set(d, 0.5f))); + } }; template <> @@ -864,6 +910,12 @@ struct ExpImpl { return DemoteTo(Rebind(), x); } + // Rounds double to nearest int32_t + template + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return DemoteToNearestInt(Rebind(), x); + } + template HWY_INLINE V ExpPoly(D d, V x) { const auto k0 = Set(d, +0.5); @@ -910,6 +962,13 @@ struct ExpImpl { x = MulAdd(qf, kLn2Part1d, x); return x; } + + template + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, PromoteTo(d, q)); + return MulAdd(x_frac, Set(d, 0.1931471805599453139823396), + Mul(x_frac, Set(d, 0.5))); + } }; template <> @@ -1029,13 +1088,13 @@ HWY_INLINE void SinCos3(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, const VI ci_1 = Set(di, 1); const VI ci_2 = Set(di, 2); const VI ci_4 = Set(di, 4); - const V cos_p0 = Set(d, T(2.443315711809948E-005)); - const V cos_p1 = Set(d, T(-1.388731625493765E-003)); - const V cos_p2 = Set(d, T(4.166664568298827E-002)); - const V sin_p0 = Set(d, T(-1.9515295891E-4)); - const V sin_p1 = Set(d, T(8.3321608736E-3)); - const V sin_p2 = Set(d, T(-1.6666654611E-1)); - const V FOPI = Set(d, T(1.27323954473516)); // 4 / M_PI + const V cos_p0 = Set(d, ConvertScalarTo(2.443315711809948E-005)); + const V cos_p1 = Set(d, ConvertScalarTo(-1.388731625493765E-003)); + const V cos_p2 = Set(d, ConvertScalarTo(4.166664568298827E-002)); + const V sin_p0 = Set(d, ConvertScalarTo(-1.9515295891E-4)); + const V sin_p1 = Set(d, ConvertScalarTo(8.3321608736E-3)); + const V sin_p2 = Set(d, ConvertScalarTo(-1.6666654611E-1)); + const V FOPI = Set(d, ConvertScalarTo(1.27323954473516)); // 4 / M_PI const V DP1 = Set(d, dp1); const V DP2 = Set(d, dp2); const V DP3 = Set(d, dp3); @@ -1128,19 +1187,20 @@ HWY_INLINE void SinCos6(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, const VI ci_1 = Set(di, 1); const VI ci_2 = Set(di, 2); const VI ci_4 = Set(di, 4); - const V cos_p0 = Set(d, T(-1.13585365213876817300E-11)); - const V cos_p1 = Set(d, T(2.08757008419747316778E-9)); - const V cos_p2 = Set(d, T(-2.75573141792967388112E-7)); - const V cos_p3 = Set(d, T(2.48015872888517045348E-5)); - const V cos_p4 = Set(d, 
T(-1.38888888888730564116E-3)); - const V cos_p5 = Set(d, T(4.16666666666665929218E-2)); - const V sin_p0 = Set(d, T(1.58962301576546568060E-10)); - const V sin_p1 = Set(d, T(-2.50507477628578072866E-8)); - const V sin_p2 = Set(d, T(2.75573136213857245213E-6)); - const V sin_p3 = Set(d, T(-1.98412698295895385996E-4)); - const V sin_p4 = Set(d, T(8.33333333332211858878E-3)); - const V sin_p5 = Set(d, T(-1.66666666666666307295E-1)); - const V FOPI = Set(d, T(1.2732395447351626861510701069801148)); // 4 / M_PI + const V cos_p0 = Set(d, ConvertScalarTo(-1.13585365213876817300E-11)); + const V cos_p1 = Set(d, ConvertScalarTo(2.08757008419747316778E-9)); + const V cos_p2 = Set(d, ConvertScalarTo(-2.75573141792967388112E-7)); + const V cos_p3 = Set(d, ConvertScalarTo(2.48015872888517045348E-5)); + const V cos_p4 = Set(d, ConvertScalarTo(-1.38888888888730564116E-3)); + const V cos_p5 = Set(d, ConvertScalarTo(4.16666666666665929218E-2)); + const V sin_p0 = Set(d, ConvertScalarTo(1.58962301576546568060E-10)); + const V sin_p1 = Set(d, ConvertScalarTo(-2.50507477628578072866E-8)); + const V sin_p2 = Set(d, ConvertScalarTo(2.75573136213857245213E-6)); + const V sin_p3 = Set(d, ConvertScalarTo(-1.98412698295895385996E-4)); + const V sin_p4 = Set(d, ConvertScalarTo(8.33333333332211858878E-3)); + const V sin_p5 = Set(d, ConvertScalarTo(-1.66666666666666307295E-1)); + const V FOPI = // 4 / M_PI + Set(d, ConvertScalarTo(1.2732395447351626861510701069801148)); const V DP1 = Set(d, dp1); const V DP2 = Set(d, dp2); const V DP3 = Set(d, dp3); @@ -1426,6 +1486,25 @@ HWY_INLINE V Exp(const D d, V x) { return IfThenElseZero(Ge(x, kLowerBound), y); } +template +HWY_INLINE V Exp2(const D d, V x) { + using T = TFromD; + + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -150.0 : -1075.0))); + const V kOne = Set(d, static_cast(+1.0)); + + impl::ExpImpl impl; + + // q = static_cast(std::lrint(x)) + const auto q = impl.ToNearestInt32(d, x); + + // Reduce, approximate, and then reconstruct. 
+ const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.Exp2Reduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + template HWY_INLINE V Expm1(const D d, V x) { using T = TFromD; @@ -1541,6 +1620,130 @@ HWY_INLINE void SinCos(const D d, V x, V& s, V& c) { impl.SinCos(d, x, s, c); } +template +HWY_INLINE V Hypot(const D d, V a, V b) { + using T = TFromD; + using TI = MakeSigned; + const RebindToUnsigned du; + const RebindToSigned di; + using VI = VFromD; + + constexpr int kMaxBiasedExp = static_cast(MaxExponentField()); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr int kNumOfMantBits = MantissaBits(); + static_assert(kNumOfMantBits > 0, "kNumOfMantBits > 0 must be true"); + + constexpr int kExpBias = kMaxBiasedExp / 2; + + static_assert( + static_cast(kExpBias) + static_cast(kNumOfMantBits) < + static_cast(kMaxBiasedExp), + "kExpBias + kNumOfMantBits < kMaxBiasedExp must be true"); + + // kMinValToSquareBiasedExp is the smallest biased exponent such that + // pow(pow(2, kMinValToSquareBiasedExp - kExpBias) * x, 2) is either a normal + // floating-point value or infinity if x is a non-zero, non-NaN value + constexpr int kMinValToSquareBiasedExp = (kExpBias / 2) + kNumOfMantBits; + static_assert(kMinValToSquareBiasedExp < kExpBias, + "kMinValToSquareBiasedExp < kExpBias must be true"); + + // kMaxValToSquareBiasedExp is the largest biased exponent such that + // pow(pow(2, kMaxValToSquareBiasedExp - kExpBias) * x, 2) * 2 is guaranteed + // to be a finite value if x is a finite value + constexpr int kMaxValToSquareBiasedExp = kExpBias + ((kExpBias / 2) - 1); + static_assert(kMaxValToSquareBiasedExp > kExpBias, + "kMaxValToSquareBiasedExp > kExpBias must be true"); + static_assert(kMaxValToSquareBiasedExp < kMaxBiasedExp, + "kMaxValToSquareBiasedExp < kMaxBiasedExp must be true"); + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 || \ + HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 + using TExpSatSub = MakeUnsigned; + using TExpMinMax = TI; +#else + using TExpSatSub = uint16_t; + using TExpMinMax = int16_t; +#endif + + const Repartition d_exp_sat_sub; + const Repartition d_exp_min_max; + + const V abs_a = Abs(a); + const V abs_b = Abs(b); + + const MFromD either_inf = Or(IsInf(a), IsInf(b)); + + const VI zero = Zero(di); + + // exp_a[i] is the biased exponent of abs_a[i] + const VI exp_a = BitCast(di, ShiftRight(BitCast(du, abs_a))); + + // exp_b[i] is the biased exponent of abs_b[i] + const VI exp_b = BitCast(di, ShiftRight(BitCast(du, abs_b))); + + // max_exp[i] is equal to HWY_MAX(exp_a[i], exp_b[i]) + + // If abs_a[i] and abs_b[i] are both NaN values, max_exp[i] will be equal to + // the biased exponent of the larger value. Otherwise, if either abs_a[i] or + // abs_b[i] is NaN, max_exp[i] will be equal to kMaxBiasedExp. + const VI max_exp = BitCast( + di, Max(BitCast(d_exp_min_max, exp_a), BitCast(d_exp_min_max, exp_b))); + + // If either abs_a[i] or abs_b[i] is zero, min_exp[i] is equal to max_exp[i]. + // Otherwise, if abs_a[i] and abs_b[i] are both nonzero, min_exp[i] is equal + // to HWY_MIN(exp_a[i], exp_b[i]). 
+ const VI min_exp = IfThenElse( + Or(Eq(BitCast(di, abs_a), zero), Eq(BitCast(di, abs_b), zero)), max_exp, + BitCast(di, Min(BitCast(d_exp_min_max, exp_a), + BitCast(d_exp_min_max, exp_b)))); + + // scl_pow2[i] is the power of 2 to scale abs_a[i] and abs_b[i] by + + // abs_a[i] and abs_b[i] should be scaled by a factor that is greater than + // zero but less than or equal to + // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that that the + // multiplications or addition operations do not overflow if + // std::hypot(abs_a[i], abs_b[i]) is finite + + // If either abs_a[i] or abs_b[i] is a a positive value that is less than + // pow(2, kMinValToSquareBiasedExp - kExpBias), then scaling up abs_a[i] and + // abs_b[i] by pow(2, kMinValToSquareBiasedExp - min_exp[i]) will ensure that + // the multiplications and additions result in normal floating point values, + // infinities, or NaNs. + + // If HWY_MAX(kMinValToSquareBiasedExp - min_exp[i], 0) is greater than + // kMaxValToSquareBiasedExp - max_exp[i], scale abs_a[i] and abs_b[i] up by + // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that the + // multiplication and addition operations result in a finite result if + // std::hypot(abs_a[i], abs_b[i]) is finite. + + const VI scl_pow2 = BitCast( + di, + Min(BitCast(d_exp_min_max, + SaturatedSub(BitCast(d_exp_sat_sub, + Set(di, static_cast( + kMinValToSquareBiasedExp))), + BitCast(d_exp_sat_sub, min_exp))), + BitCast(d_exp_min_max, + Sub(Set(di, static_cast(kMaxValToSquareBiasedExp)), + max_exp)))); + + const VI exp_bias = Set(di, static_cast(kExpBias)); + + const V ab_scl_factor = + BitCast(d, ShiftLeft(Add(exp_bias, scl_pow2))); + const V hypot_scl_factor = + BitCast(d, ShiftLeft(Sub(exp_bias, scl_pow2))); + + const V scl_a = Mul(abs_a, ab_scl_factor); + const V scl_b = Mul(abs_b, ab_scl_factor); + + const V scl_hypot = Sqrt(MulAdd(scl_a, scl_a, Mul(scl_b, scl_b))); + // std::hypot returns inf if one input is +/- inf, even if the other is NaN. + return IfThenElse(either_inf, Inf(d), Mul(scl_hypot, hypot_scl_factor)); +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/r/src/vendor/highway/hwy/contrib/math/math_test.cc b/r/src/vendor/highway/hwy/contrib/math/math_test.cc index 6de83743..ef9eec3d 100644 --- a/r/src/vendor/highway/hwy/contrib/math/math_test.cc +++ b/r/src/vendor/highway/hwy/contrib/math/math_test.cc @@ -13,11 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include // FLT_MAX #include // std::abs +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" + // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/contrib/math/math_test.cc" @@ -30,6 +34,7 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +namespace { // We have had test failures caused by excess precision due to keeping // intermediate results in 80-bit x87 registers. One such failure mode is that @@ -39,21 +44,15 @@ namespace HWY_NAMESPACE { #if HWY_ARCH_X86_32 && HWY_COMPILER_GCC_ACTUAL && \ (HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128) -// On 32-bit x86 with GCC 13+, build with `-fexcess-precision=standard` - see +// GCC 13+: because CMAKE_CXX_EXTENSIONS is OFF, we build with -std= and hence +// also -fexcess-precision=standard, so there is no problem. See #1708 and // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=323. 
#if HWY_COMPILER_GCC_ACTUAL >= 1300 - -#if FLT_EVAL_METHOD == 0 // correct flag given, no problem #define HWY_MATH_TEST_EXCESS_PRECISION 0 -#else -#define HWY_MATH_TEST_EXCESS_PRECISION 1 -#pragma message( \ - "Skipping scalar math_test on 32-bit x86 GCC 13+ without -fexcess-precision=standard") -#endif // FLT_EVAL_METHOD #else // HWY_COMPILER_GCC_ACTUAL < 1300 -// On 32-bit x86 with GCC <13, set HWY_CMAKE_SSE2 - see +// The build system must enable SSE2, e.g. via HWY_CMAKE_SSE2 - see // https://stackoverflow.com/questions/20869904/c-handling-of-excess-precision . #if defined(__SSE2__) // correct flag given, no problem #define HWY_MATH_TEST_EXCESS_PRECISION 0 @@ -68,14 +67,6 @@ namespace HWY_NAMESPACE { #define HWY_MATH_TEST_EXCESS_PRECISION 0 #endif // HWY_ARCH_X86_32 etc -template -inline Out BitCast(const In& in) { - static_assert(sizeof(Out) == sizeof(In), ""); - Out out; - CopyBytes(&in, &out); - return out; -} - template HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), Vec (*fxN)(D, VecArg>), D d, T min, T max, @@ -84,24 +75,24 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), static bool once = true; if (once) { once = false; - fprintf(stderr, - "Skipping math_test due to GCC issue with excess precision.\n"); + HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n"); } + return; } using UintT = MakeUnsigned; - const UintT min_bits = BitCast(min); - const UintT max_bits = BitCast(max); + const UintT min_bits = BitCastScalar(min); + const UintT max_bits = BitCastScalar(max); // If min is negative and max is positive, the range needs to be broken into // two pieces, [+0, max] and [-0, min], otherwise [min, max]. int range_count = 1; UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}}; if ((min < 0.0) && (max > 0.0)) { - ranges[0][0] = BitCast(static_cast(+0.0)); + ranges[0][0] = BitCastScalar(ConvertScalarTo(+0.0)); ranges[0][1] = max_bits; - ranges[1][0] = BitCast(static_cast(-0.0)); + ranges[1][0] = BitCastScalar(ConvertScalarTo(-0.0)); ranges[1][1] = min_bits; range_count = 2; } @@ -116,7 +107,8 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), for (UintT value_bits = start; value_bits <= stop; value_bits += step) { // For reasons unknown, the HWY_MAX is necessary on RVV, otherwise // value_bits can be less than start, and thus possibly NaN. - const T value = BitCast(HWY_MIN(HWY_MAX(start, value_bits), stop)); + const T value = + BitCastScalar(HWY_MIN(HWY_MAX(start, value_bits), stop)); const T actual = GetLane(fxN(d, Set(d, value))); const T expected = fx1(value); @@ -130,7 +122,7 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected); max_ulp = HWY_MAX(max_ulp, ulp); if (ulp > max_error_ulp) { - fprintf(stderr, "%s: %s(%f) expected %f actual %f ulp %g max ulp %u\n", + fprintf(stderr, "%s: %s(%f) expected %E actual %E ulp %g max ulp %u\n", hwy::TypeName(T(), Lanes(d)).c_str(), name, value, expected, actual, static_cast(ulp), static_cast(max_error_ulp)); @@ -165,9 +157,11 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T), }; \ DEFINE_MATH_TEST_FUNC(NAME) -// Floating point values closest to but less than 1.0 -const float kNearOneF = BitCast(0x3F7FFFFF); -const double kNearOneD = BitCast(0x3FEFFFFFFFFFFFFFULL); +// Floating point values closest to but less than 1.0. Avoid variables with +// static initializers inside HWY_BEFORE_NAMESPACE/HWY_AFTER_NAMESPACE to +// ensure target-specific code does not leak into startup code. 
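The two bit patterns used just below encode the largest float and double strictly less than 1.0. A quick standalone check of that claim (illustrative only, not part of the test):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t f_bits = 0x3F7FFFFFu;
  float f;
  std::memcpy(&f, &f_bits, sizeof(f));
  assert(f == std::nextafter(1.0f, 0.0f));  // 1 - 2^-24

  const uint64_t d_bits = 0x3FEFFFFFFFFFFFFFull;
  double d;
  std::memcpy(&d, &d_bits, sizeof(d));
  assert(d == std::nextafter(1.0, 0.0));  // 1 - 2^-53
  return 0;
}
```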
+float kNearOneF() { return BitCastScalar(0x3F7FFFFF); } +double kNearOneD() { return BitCastScalar(0x3FEFFFFFFFFFFFFFULL); } // The discrepancy is unacceptably large for MSYS2 (less accurate libm?), so // only increase the error tolerance there. @@ -190,14 +184,14 @@ constexpr uint64_t ACosh32ULP() { template static Vec SinCosSin(const D d, VecArg> x) { Vec s, c; - SinCos(d, x, s, c); + CallSinCos(d, x, s, c); return s; } template static Vec SinCosCos(const D d, VecArg> x) { Vec s, c; - SinCos(d, x, s, c); + CallSinCos(d, x, s, c); return c; } @@ -234,15 +228,19 @@ DEFINE_MATH_TEST(Asinh, DEFINE_MATH_TEST(Atan, std::atan, CallAtan, -FLT_MAX, +FLT_MAX, 3, std::atan, CallAtan, -DBL_MAX, +DBL_MAX, 3) +// NEON has ULP 4 instead of 3 DEFINE_MATH_TEST(Atanh, - std::atanh, CallAtanh, -kNearOneF, +kNearOneF, 4, // NEON is 4 instead of 3 - std::atanh, CallAtanh, -kNearOneD, +kNearOneD, 3) + std::atanh, CallAtanh, -kNearOneF(), +kNearOneF(), 4, + std::atanh, CallAtanh, -kNearOneD(), +kNearOneD(), 3) DEFINE_MATH_TEST(Cos, std::cos, CallCos, -39000.0f, +39000.0f, 3, std::cos, CallCos, -39000.0, +39000.0, Cos64ULP()) DEFINE_MATH_TEST(Exp, std::exp, CallExp, -FLT_MAX, +104.0f, 1, std::exp, CallExp, -DBL_MAX, +104.0, 1) +DEFINE_MATH_TEST(Exp2, + std::exp2, CallExp2, -FLT_MAX, +128.0f, 2, + std::exp2, CallExp2, -DBL_MAX, +128.0, 2) DEFINE_MATH_TEST(Expm1, std::expm1, CallExpm1, -FLT_MAX, +104.0f, 4, std::expm1, CallExpm1, -DBL_MAX, +104.0, 4) @@ -285,33 +283,38 @@ void Atan2TestCases(T /*unused*/, D d, size_t& padded, T x; T expected; }; - const T pos = static_cast(1E5); - const T neg = static_cast(-1E7); - // T{-0} is not enough to get an actual negative zero. - const T n0 = static_cast(-0.0); + const T pos = ConvertScalarTo(1E5); + const T neg = ConvertScalarTo(-1E7); + const T p0 = ConvertScalarTo(0); + // -0 is not enough to get an actual negative zero. 
+ const T n0 = ConvertScalarTo(-0.0); + const T p1 = ConvertScalarTo(1); + const T n1 = ConvertScalarTo(-1); + const T p2 = ConvertScalarTo(2); + const T n2 = ConvertScalarTo(-2); const T inf = GetLane(Inf(d)); const T nan = GetLane(NaN(d)); - const T pi = static_cast(3.141592653589793238); - const YX test_cases[] = { // 45 degree steps: - {T{0.0}, T{1.0}, T{0}}, // E - {T{-1.0}, T{1.0}, -pi / 4}, // SE - {T{-1.0}, T{0.0}, -pi / 2}, // S - {T{-1.0}, T{-1.0}, -3 * pi / 4}, // SW - {T{0.0}, T{-1.0}, pi}, // W - {T{1.0}, T{-1.0}, 3 * pi / 4}, // NW - {T{1.0}, T{0.0}, pi / 2}, // N - {T{1.0}, T{1.0}, pi / 4}, // NE + const T pi = ConvertScalarTo(3.141592653589793238); + const YX test_cases[] = { // 45 degree steps: + {p0, p1, p0}, // E + {n1, p1, -pi / 4}, // SE + {n1, p0, -pi / 2}, // S + {n1, n1, -3 * pi / 4}, // SW + {p0, n1, pi}, // W + {p1, n1, 3 * pi / 4}, // NW + {p1, p0, pi / 2}, // N + {p1, p1, pi / 4}, // NE // y = ±0, x < 0 or -0 - {T{0}, T{-1}, pi}, - {n0, T{-2}, -pi}, + {p0, n1, pi}, + {n0, n2, -pi}, // y = ±0, x > 0 or +0 - {T{0}, T{2}, T{0}}, - {n0, T{2}, n0}, + {p0, p2, p0}, + {n0, p2, n0}, // y = ±∞, x finite - {inf, T{3}, pi / 2}, - {-inf, T{3}, -pi / 2}, + {inf, p2, pi / 2}, + {-inf, p2, -pi / 2}, // y = ±∞, x = -∞ {inf, -inf, 3 * pi / 4}, {-inf, -inf, -3 * pi / 4}, @@ -319,21 +322,21 @@ void Atan2TestCases(T /*unused*/, D d, size_t& padded, {inf, inf, pi / 4}, {-inf, inf, -pi / 4}, // y < 0, x = ±0 - {T{-2}, T{0}, -pi / 2}, - {T{-1}, n0, -pi / 2}, + {n2, p0, -pi / 2}, + {n1, n0, -pi / 2}, // y > 0, x = ±0 - {pos, T{0}, pi / 2}, - {T{4}, n0, pi / 2}, + {pos, p0, pi / 2}, + {p2, n0, pi / 2}, // finite y > 0, x = -∞ {pos, -inf, pi}, // finite y < 0, x = -∞ {neg, -inf, -pi}, // finite y > 0, x = +∞ - {pos, inf, T{0}}, + {pos, inf, p0}, // finite y < 0, x = +∞ {neg, inf, n0}, // y NaN xor x NaN - {nan, T{0}, nan}, + {nan, p0, nan}, {pos, nan, nan}}; const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); const size_t N = Lanes(d); @@ -341,7 +344,7 @@ void Atan2TestCases(T /*unused*/, D d, size_t& padded, out_y = AllocateAligned(padded); out_x = AllocateAligned(padded); out_expected = AllocateAligned(padded); - HWY_ASSERT(out_y && out_x); + HWY_ASSERT(out_y && out_x && out_expected); size_t i = 0; for (; i < kNumTestCases; ++i) { out_y[i] = test_cases[i].y; @@ -349,9 +352,9 @@ void Atan2TestCases(T /*unused*/, D d, size_t& padded, out_expected[i] = test_cases[i].expected; } for (; i < padded; ++i) { - out_y[i] = T{0}; - out_x[i] = T{0}; - out_expected[i] = T{0}; + out_y[i] = p0; + out_x[i] = p0; + out_expected[i] = p0; } } @@ -364,10 +367,10 @@ struct TestAtan2 { AlignedFreeUniquePtr in_y, in_x, expected; Atan2TestCases(t, d, padded, in_y, in_x, expected); - const Vec tolerance = Set(d, T(1E-5)); + const Vec tolerance = Set(d, ConvertScalarTo(1E-5)); for (size_t i = 0; i < padded; ++i) { - const T actual = static_cast(atan2(in_y[i], in_x[i])); + const T actual = ConvertScalarTo(atan2(in_y[i], in_x[i])); // fprintf(stderr, "%zu: table %f atan2 %f\n", i, expected[i], actual); HWY_ASSERT_EQ(expected[i], actual); } @@ -393,7 +396,7 @@ struct TestAtan2 { if (!AllTrue(d, ok)) { const size_t mismatch = static_cast(FindKnownFirstTrue(d, Not(ok))); - fprintf(stderr, "Mismatch for i=%d expected %f actual %f\n", + fprintf(stderr, "Mismatch for i=%d expected %E actual %E\n", static_cast(i + mismatch), expected[i + mismatch], ExtractLane(actual, mismatch)); HWY_ASSERT(0); @@ -408,14 +411,229 @@ HWY_NOINLINE void TestAllAtan2() { ForFloat3264Types(ForPartialVectors()); } 
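As a usage sketch of the vectorized Atan2 exercised by the test above (static dispatch; the function name and the scalar tail are illustrative, and the main loop mirrors the `count >= N` idiom used throughout this patch):

```cpp
#include <cmath>

#include "hwy/highway.h"
#include "hwy/contrib/math/math-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

// out[i] = atan2(y[i], x[i]) for i in [0, count).
void AnglesFromXY(const float* HWY_RESTRICT y, const float* HWY_RESTRICT x,
                  float* HWY_RESTRICT out, size_t count) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t i = 0;
  if (count >= N) {
    for (; i <= count - N; i += N) {
      hn::StoreU(hn::Atan2(d, hn::LoadU(d, y + i), hn::LoadU(d, x + i)), d,
                 out + i);
    }
  }
  for (; i < count; ++i) {
    out[i] = std::atan2(y[i], x[i]);  // scalar tail for the remainder
  }
}
```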
+template +void HypotTestCases(T /*unused*/, D d, size_t& padded, + AlignedFreeUniquePtr& out_a, + AlignedFreeUniquePtr& out_b, + AlignedFreeUniquePtr& out_expected) { + using TU = MakeUnsigned; + + struct AB { + T a; + T b; + }; + + constexpr int kNumOfMantBits = MantissaBits(); + static_assert(kNumOfMantBits > 0, "kNumOfMantBits > 0 must be true"); + + // Ensures inputs are not constexpr. + const TU u1 = static_cast(hwy::Unpredictable1()); + const double k1 = static_cast(u1); + + const T pos = ConvertScalarTo(1E5 * k1); + const T neg = ConvertScalarTo(-1E7 * k1); + const T p0 = ConvertScalarTo(k1 - 1.0); + // -0 is not enough to get an actual negative zero. + const T n0 = ScalarCopySign(p0, neg); + const T p1 = ConvertScalarTo(k1); + const T n1 = ConvertScalarTo(-k1); + const T p2 = ConvertScalarTo(2 * k1); + const T n2 = ConvertScalarTo(-2 * k1); + const T inf = BitCastScalar(ExponentMask() * u1); + const T neg_inf = ScalarCopySign(inf, n1); + const T nan = BitCastScalar( + static_cast(ExponentMask() | (u1 << (kNumOfMantBits - 1)))); + + const double max_as_f64 = ConvertScalarTo(HighestValue()) * k1; + const T max = ConvertScalarTo(max_as_f64); + + const T huge = ConvertScalarTo(max_as_f64 * 0.25); + const T neg_huge = ScalarCopySign(huge, n1); + + const T huge2 = ConvertScalarTo(max_as_f64 * 0.039415044328304796); + + const T large = ConvertScalarTo(3.512227595593985E18 * k1); + const T neg_large = ScalarCopySign(large, n1); + const T large2 = ConvertScalarTo(2.1190576943127544E16 * k1); + + const T small = ConvertScalarTo(1.067033284841808E-11 * k1); + const T neg_small = ScalarCopySign(small, n1); + const T small2 = ConvertScalarTo(1.9401409532292856E-12 * k1); + + const T tiny = BitCastScalar(static_cast(u1 << kNumOfMantBits)); + const T neg_tiny = ScalarCopySign(tiny, n1); + + const T tiny2 = + ConvertScalarTo(78.68466968859765 * ConvertScalarTo(tiny)); + + const AB test_cases[] = {{p0, p0}, {p0, n0}, + {n0, n0}, {p1, p1}, + {p1, n1}, {n1, n1}, + {p2, p2}, {p2, n2}, + {p2, pos}, {p2, neg}, + {n2, pos}, {n2, neg}, + {n2, n2}, {p0, tiny}, + {p0, neg_tiny}, {n0, tiny}, + {n0, neg_tiny}, {p1, tiny}, + {p1, neg_tiny}, {n1, tiny}, + {n1, neg_tiny}, {tiny, p0}, + {tiny2, p0}, {tiny, tiny2}, + {neg_tiny, tiny2}, {huge, huge2}, + {neg_huge, huge2}, {huge, p0}, + {huge, tiny}, {huge2, tiny2}, + {large, p0}, {large, large2}, + {neg_large, p0}, {neg_large, large2}, + {small, p0}, {small, small2}, + {neg_small, p0}, {neg_small, small2}, + {max, p0}, {max, huge}, + {max, max}, {p0, inf}, + {n0, inf}, {p1, inf}, + {n1, inf}, {p2, inf}, + {n2, inf}, {p0, neg_inf}, + {n0, neg_inf}, {p1, neg_inf}, + {n1, neg_inf}, {p2, neg_inf}, + {n2, neg_inf}, {p0, nan}, + {n0, nan}, {p1, nan}, + {n1, nan}, {p2, nan}, + {n2, nan}, {huge, inf}, + {inf, nan}, {neg_inf, nan}, + {nan, nan}}; + + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + out_a = AllocateAligned(padded); + out_b = AllocateAligned(padded); + out_expected = AllocateAligned(padded); + HWY_ASSERT(out_a && out_b && out_expected); + + size_t i = 0; + for (; i < kNumTestCases; ++i) { + const T a = + test_cases[i].a * hwy::ConvertScalarTo(hwy::Unpredictable1()); + const T b = test_cases[i].b; + +#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7 + // Ignore test cases that have infinite or NaN inputs on Armv7 NEON + if (!ScalarIsFinite(a) || !ScalarIsFinite(b)) { + out_a[i] = p0; + out_b[i] = p0; + out_expected[i] = p0; + 
continue; + } +#endif + + out_a[i] = a; + out_b[i] = b; + + if (ScalarIsInf(a) || ScalarIsInf(b)) { + out_expected[i] = inf; + } else if (ScalarIsNaN(a) || ScalarIsNaN(b)) { + out_expected[i] = nan; + } else { + out_expected[i] = std::hypot(a, b); + } + } + for (; i < padded; ++i) { + out_a[i] = p0; + out_b[i] = p0; + out_expected[i] = p0; + } +} + +struct TestHypot { + template + HWY_NOINLINE void operator()(T t, D d) { + if (HWY_MATH_TEST_EXCESS_PRECISION) { + return; + } + + const size_t N = Lanes(d); + + constexpr uint64_t kMaxErrorUlp = 4; + + size_t padded; + AlignedFreeUniquePtr in_a, in_b, expected; + HypotTestCases(t, d, padded, in_a, in_b, expected); + + auto actual1_lanes = AllocateAligned(N); + auto actual2_lanes = AllocateAligned(N); + HWY_ASSERT(actual1_lanes && actual2_lanes); + + uint64_t max_ulp = 0; + for (size_t i = 0; i < padded; i += N) { + const auto a = Load(d, in_a.get() + i); + const auto b = Load(d, in_b.get() + i); + +#if HWY_ARCH_ARM_A64 + // TODO(b/287462770): inline to work around incorrect SVE codegen + const auto actual1 = Hypot(d, a, b); + const auto actual2 = Hypot(d, b, a); +#else + const auto actual1 = CallHypot(d, a, b); + const auto actual2 = CallHypot(d, b, a); +#endif + + Store(actual1, d, actual1_lanes.get()); + Store(actual2, d, actual2_lanes.get()); + + for (size_t j = 0; j < N; j++) { + const T val_a = in_a[i + j]; + const T val_b = in_b[i + j]; + const T expected_val = expected[i + j]; + const T actual1_val = actual1_lanes[j]; + const T actual2_val = actual2_lanes[j]; + + const auto ulp1 = + hwy::detail::ComputeUlpDelta(actual1_val, expected_val); + if (ulp1 > kMaxErrorUlp) { + fprintf(stderr, + "%s: Hypot(%e, %e) lane %d expected %E actual %E ulp %g max " + "ulp %u\n", + hwy::TypeName(T(), Lanes(d)).c_str(), val_a, val_b, + static_cast(j), expected_val, actual1_val, + static_cast(ulp1), + static_cast(kMaxErrorUlp)); + } + + const auto ulp2 = + hwy::detail::ComputeUlpDelta(actual2_val, expected_val); + if (ulp2 > kMaxErrorUlp) { + fprintf(stderr, + "%s: Hypot(%e, %e) expected %E actual %E ulp %g max ulp %u\n", + hwy::TypeName(T(), Lanes(d)).c_str(), val_b, val_a, + expected_val, actual2_val, static_cast(ulp2), + static_cast(kMaxErrorUlp)); + } + + max_ulp = HWY_MAX(max_ulp, HWY_MAX(ulp1, ulp2)); + } + } + + if (max_ulp != 0) { + fprintf(stderr, "%s: Hypot max_ulp %g\n", + hwy::TypeName(T(), Lanes(d)).c_str(), + static_cast(max_ulp)); + HWY_ASSERT(max_ulp <= kMaxErrorUlp); + } + } +}; + +HWY_NOINLINE void TestAllHypot() { + if (HWY_MATH_TEST_EXCESS_PRECISION) return; + + ForFloat3264Types(ForPartialVectors()); +} + +} // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace hwy { +namespace { HWY_BEFORE_TEST(HwyMathTest); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcos); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcosh); @@ -425,6 +643,7 @@ HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtan); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtanh); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllCos); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp2); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10); @@ -436,6 +655,9 @@ HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllTanh); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtan2); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSinCosSin); HWY_EXPORT_AND_TEST_P(HwyMathTest, 
TestAllSinCosCos); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllHypot); +HWY_AFTER_TEST(); +} // namespace } // namespace hwy - -#endif +HWY_TEST_MAIN(); +#endif // HWY_ONCE diff --git a/r/src/vendor/highway/hwy/detect_compiler_arch.h b/r/src/vendor/highway/hwy/detect_compiler_arch.h index 081b6fff..94b49cb7 100644 --- a/r/src/vendor/highway/hwy/detect_compiler_arch.h +++ b/r/src/vendor/highway/hwy/detect_compiler_arch.h @@ -73,7 +73,13 @@ // https://github.com/simd-everywhere/simde/blob/47d6e603de9d04ee05cdfbc57cf282a02be1bf2a/simde/simde-detect-clang.h#L59. // Please send updates below to them as well, thanks! #if defined(__apple_build_version__) || __clang_major__ >= 999 -#if __has_attribute(nouwtable) // no new warnings in 16.0 +#if __has_warning("-Woverriding-option") +#define HWY_COMPILER_CLANG 1801 +// No new warnings in 17.0, and Apple LLVM 15.3, which should be 1600, already +// has the unsafe_buffer_usage attribute, so we instead check for new builtins. +#elif __has_builtin(__builtin_nondeterministic_value) +#define HWY_COMPILER_CLANG 1700 +#elif __has_attribute(nouwtable) // no new warnings in 16.0 #define HWY_COMPILER_CLANG 1600 #elif __has_warning("-Warray-parameter") #define HWY_COMPILER_CLANG 1500 @@ -113,7 +119,8 @@ #define HWY_COMPILER3_CLANG 0 #endif -#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC && \ + !HWY_COMPILER_ICX #define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC #else #define HWY_COMPILER_GCC_ACTUAL 0 @@ -121,17 +128,20 @@ // More than one may be nonzero, but we want at least one. #if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \ - HWY_COMPILER_GCC + HWY_COMPILER_CLANG) + HWY_COMPILER_ICX + HWY_COMPILER_GCC + HWY_COMPILER_CLANG) #error "Unsupported compiler" #endif -// We should only detect one of these (only clang/clangcl overlap) -#if 1 < \ - (!!HWY_COMPILER_MSVC + !!HWY_COMPILER_ICC + !!HWY_COMPILER_GCC_ACTUAL + \ - !!(HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) +// We should only detect one of these (only clang/clangcl/icx overlap) +#if 1 < (!!HWY_COMPILER_MSVC + (!!HWY_COMPILER_ICC & !HWY_COMPILER_ICX) + \ + !!HWY_COMPILER_GCC_ACTUAL + \ + !!(HWY_COMPILER_ICX | HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) #error "Detected multiple compilers" #endif +//------------------------------------------------------------------------------ +// Compiler features and C++ version + #ifdef __has_builtin #define HWY_HAS_BUILTIN(name) __has_builtin(name) #else @@ -156,6 +166,32 @@ #define HWY_HAS_FEATURE(name) 0 #endif +// NOTE: clang ~17 does not correctly handle wrapping __has_include in a macro. 
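
The Hypot test above accepts results within `kMaxErrorUlp` of `std::hypot` rather than requiring bitwise equality. As a minimal standalone sketch of what a bit-level ULP distance for `float` looks like (illustrative only; this is not Highway's `hwy::detail::ComputeUlpDelta`, and it ignores signs, zeros and NaN):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Number of representable float steps between two finite, same-sign values.
uint32_t UlpDelta(float a, float b) {
  uint32_t ua, ub;
  std::memcpy(&ua, &a, sizeof(ua));  // type-pun via memcpy to avoid UB
  std::memcpy(&ub, &b, sizeof(ub));
  return (ua >= ub) ? (ua - ub) : (ub - ua);
}

int main() {
  const float x = 1.0f;
  const float y = std::nextafter(x, 2.0f);  // exactly one ULP above x
  return UlpDelta(x, y) == 1 ? 0 : 1;       // exits 0 on success
}
```

For finite values of the same sign, the IEEE-754 bit patterns are monotonically ordered, so the difference of the reinterpreted integers equals the ULP distance.
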
+ +#if HWY_COMPILER_MSVC && defined(_MSVC_LANG) && _MSVC_LANG > __cplusplus +#define HWY_CXX_LANG _MSVC_LANG +#else +#define HWY_CXX_LANG __cplusplus +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201603L +#define HWY_CXX17_CONSTEXPR constexpr +#else +#define HWY_CXX17_CONSTEXPR +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L +#define HWY_CXX14_CONSTEXPR constexpr +#else +#define HWY_CXX14_CONSTEXPR +#endif + +#if HWY_CXX_LANG >= 201703L +#define HWY_IF_CONSTEXPR if constexpr +#else +#define HWY_IF_CONSTEXPR if +#endif + //------------------------------------------------------------------------------ // Architecture @@ -187,6 +223,12 @@ #define HWY_ARCH_PPC 0 #endif +#if defined(__powerpc64__) || (HWY_ARCH_PPC && defined(__64BIT__)) +#define HWY_ARCH_PPC_64 1 +#else +#define HWY_ARCH_PPC_64 0 +#endif + // aarch32 is currently not supported; please raise an issue if you want it. #if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64) #define HWY_ARCH_ARM_A64 1 @@ -225,18 +267,52 @@ #endif #ifdef __riscv -#define HWY_ARCH_RVV 1 +#define HWY_ARCH_RISCV 1 +#else +#define HWY_ARCH_RISCV 0 +#endif +// DEPRECATED names; please use HWY_ARCH_RISCV instead. +#define HWY_ARCH_RVV HWY_ARCH_RISCV + +#if HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if __riscv_xlen == 32 +#define HWY_ARCH_RISCV_32 1 +#else +#define HWY_ARCH_RISCV_32 0 +#endif + +#if __riscv_xlen == 64 +#define HWY_ARCH_RISCV_64 1 #else -#define HWY_ARCH_RVV 0 +#define HWY_ARCH_RISCV_64 0 +#endif + +#else // !HWY_ARCH_RISCV || !defined(__riscv_xlen) +#define HWY_ARCH_RISCV_32 0 +#define HWY_ARCH_RISCV_64 0 +#endif // HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if HWY_ARCH_RISCV_32 && HWY_ARCH_RISCV_64 +#error "Cannot have both RISCV_32 and RISCV_64" +#endif + +#if defined(__s390x__) +#define HWY_ARCH_S390X 1 +#else +#define HWY_ARCH_S390X 0 #endif // It is an error to detect multiple architectures at the same time, but OK to // detect none of the above. #if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \ - HWY_ARCH_WASM + HWY_ARCH_RVV) > 1 + HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X) > 1 #error "Must not detect more than one architecture" #endif +//------------------------------------------------------------------------------ +// Operating system + #if defined(_WIN32) || defined(_WIN64) #define HWY_OS_WIN 1 #else @@ -249,6 +325,25 @@ #define HWY_OS_LINUX 0 #endif +// iOS or Mac +#if defined(__APPLE__) +#define HWY_OS_APPLE 1 +#else +#define HWY_OS_APPLE 0 +#endif + +#if defined(__FreeBSD__) +#define HWY_OS_FREEBSD 1 +#else +#define HWY_OS_FREEBSD 0 +#endif + +// It is an error to detect multiple OSes at the same time, but OK to +// detect none of the above. +#if (HWY_OS_WIN + HWY_OS_LINUX + HWY_OS_APPLE + HWY_OS_FREEBSD) > 1 +#error "Must not detect more than one OS" +#endif + //------------------------------------------------------------------------------ // Endianness diff --git a/r/src/vendor/highway/hwy/detect_targets.h b/r/src/vendor/highway/hwy/detect_targets.h index ccab425a..d5f3ab07 100644 --- a/r/src/vendor/highway/hwy/detect_targets.h +++ b/r/src/vendor/highway/hwy/detect_targets.h @@ -62,7 +62,8 @@ // Bits 0..3 reserved (4 targets) #define HWY_AVX3_SPR (1LL << 4) // Bit 5 reserved (likely AVX10.2 with 256-bit vectors) -// Currently HWY_AVX3_DL plus a special case for CompressStore (10x as fast). +// Currently HWY_AVX3_DL plus AVX512BF16 and a special case for CompressStore +// (10x as fast). // We may later also use VPCONFLICT. 
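
The `HWY_CXX17_CONSTEXPR`/`HWY_IF_CONSTEXPR` helpers introduced earlier in this hunk let one source tree build under both C++14 and C++17. A hedged sketch of the intended usage (the `CopyN` helper is hypothetical, not Highway API); note that when the macro degrades to a plain `if`, both branches must still compile for every instantiation:

```cpp
#include <cstddef>
#include <cstring>
#include <type_traits>

// Mirrors the definition above: `if constexpr` under C++17, plain `if` otherwise.
#ifndef HWY_IF_CONSTEXPR
#if __cplusplus >= 201703L
#define HWY_IF_CONSTEXPR if constexpr
#else
#define HWY_IF_CONSTEXPR if
#endif
#endif

// Hypothetical helper: memcpy trivially copyable types, otherwise copy per element.
template <typename T>
void CopyN(const T* from, size_t n, T* to) {
  HWY_IF_CONSTEXPR(std::is_trivially_copyable<T>::value) {
    std::memcpy(to, from, n * sizeof(T));
  } else {
    for (size_t i = 0; i < n; ++i) to[i] = from[i];
  }
}

int main() {
  int src[4] = {1, 2, 3, 4};
  int dst[4] = {0, 0, 0, 0};
  CopyN(src, 4, dst);
  return dst[3] == 4 ? 0 : 1;
}
```
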
#define HWY_AVX3_ZEN4 (1LL << 6) // see HWY_WANT_AVX3_ZEN4 below @@ -84,15 +85,22 @@ #define HWY_HIGHEST_TARGET_BIT_X86 14 // --------------------------- Arm: 15 targets (+ one fallback) -// Bits 15..23 reserved (9 targets) -#define HWY_SVE2_128 (1LL << 24) // specialized target (e.g. Arm N2) -#define HWY_SVE_256 (1LL << 25) // specialized target (e.g. Arm V1) -#define HWY_SVE2 (1LL << 26) -#define HWY_SVE (1LL << 27) +// Bits 15..17 reserved (3 targets) +#define HWY_SVE2_128 (1LL << 18) // specialized (e.g. Neoverse V2/N2/N3) +#define HWY_SVE_256 (1LL << 19) // specialized (Neoverse V1) +// Bits 20-22 reserved for later SVE (3 targets) +#define HWY_SVE2 (1LL << 23) +#define HWY_SVE (1LL << 24) +// Bit 25 reserved for NEON +#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. Neoverse V2/N2/N3) +// Bit 27 reserved for NEON #define HWY_NEON (1LL << 28) // Implies support for AES #define HWY_NEON_WITHOUT_AES (1LL << 29) #define HWY_HIGHEST_TARGET_BIT_ARM 29 +#define HWY_ALL_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON | HWY_NEON_BF16) +#define HWY_ALL_SVE (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) + // --------------------------- RISC-V: 9 targets (+ one fallback) // Bits 30..36 reserved (7 targets) #define HWY_RVV (1LL << 37) @@ -102,14 +110,17 @@ // --------------------------- Future expansion: 4 targets // Bits 39..42 reserved -// --------------------------- IBM Power: 9 targets (+ one fallback) +// --------------------------- IBM Power/ZSeries: 9 targets (+ one fallback) // Bits 43..46 reserved (4 targets) #define HWY_PPC10 (1LL << 47) // v3.1 #define HWY_PPC9 (1LL << 48) // v3.0 #define HWY_PPC8 (1LL << 49) // v2.07 -// Bits 50..51 reserved for prior VSX/AltiVec (2 targets) +#define HWY_Z15 (1LL << 50) // Z15 +#define HWY_Z14 (1LL << 51) // Z14 #define HWY_HIGHEST_TARGET_BIT_PPC 51 +#define HWY_ALL_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + // --------------------------- WebAssembly: 9 targets (+ one fallback) // Bits 52..57 reserved (6 targets) #define HWY_WASM_EMU256 (1LL << 58) // Experimental @@ -187,7 +198,7 @@ // armv7be has not been tested and is not yet supported. #if HWY_ARCH_ARM_V7 && HWY_IS_BIG_ENDIAN -#define HWY_BROKEN_ARM7_BIG_ENDIAN (HWY_NEON | HWY_NEON_WITHOUT_AES) +#define HWY_BROKEN_ARM7_BIG_ENDIAN HWY_ALL_NEON #else #define HWY_BROKEN_ARM7_BIG_ENDIAN 0 #endif @@ -198,14 +209,26 @@ #if HWY_ARCH_ARM_V7 && (__ARM_ARCH_PROFILE == 'A') && \ !defined(__ARM_VFPV4__) && \ !((__ARM_NEON_FP & 0x2 /* half-float */) && (__ARM_FEATURE_FMA == 1)) -#define HWY_BROKEN_ARM7_WITHOUT_VFP4 (HWY_NEON | HWY_NEON_WITHOUT_AES) +#define HWY_BROKEN_ARM7_WITHOUT_VFP4 HWY_ALL_NEON #else #define HWY_BROKEN_ARM7_WITHOUT_VFP4 0 #endif +// HWY_NEON_BF16 requires recent compilers. +#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1700) || \ + (HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302) +#define HWY_BROKEN_NEON_BF16 (HWY_NEON_BF16) +#else +#define HWY_BROKEN_NEON_BF16 0 +#endif + // SVE[2] require recent clang or gcc versions. -#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100) || \ - (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) + +// In addition, SVE[2] is not currently supported by any Apple CPU (at least up +// to and including M4 and A18). 
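
The target groups `HWY_ALL_NEON` and `HWY_ALL_SVE` defined above are simply ORs of single-bit target identifiers, so several related targets can be tested with one mask. A small hedged sketch of querying such a mask at runtime, assuming `hwy::SupportedTargets()` returns the 64-bit mask of usable targets (as the dispatch code later in this patch relies on); the printing is illustrative only:

```cpp
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"  // pulls in detect_targets.h and targets.h

int main() {
  // Bitmask of targets usable on this CPU (plus the always-available baseline).
  const int64_t supported = hwy::SupportedTargets();
  if (supported & HWY_ALL_SVE) {
    std::printf("an SVE target is usable\n");
  } else if (supported & HWY_ALL_NEON) {
    std::printf("a NEON target is usable\n");
  } else {
    std::printf("supported mask: 0x%llx\n",
                static_cast<unsigned long long>(supported));
  }
  return 0;
}
```
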
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1900) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE #define HWY_BROKEN_SVE (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) #else #define HWY_BROKEN_SVE 0 @@ -239,6 +262,22 @@ #define HWY_BROKEN_PPC10 0 #endif +// PPC8/PPC9/PPC10 targets may fail to compile on 32-bit PowerPC +#if HWY_ARCH_PPC && !HWY_ARCH_PPC_64 +#define HWY_BROKEN_PPC_32BIT (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) +#else +#define HWY_BROKEN_PPC_32BIT 0 +#endif + +// HWY_RVV fails to compile with GCC < 13 or Clang < 16. +#if HWY_ARCH_RISCV && \ + ((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300)) +#define HWY_BROKEN_RVV (HWY_RVV) +#else +#define HWY_BROKEN_RVV 0 +#endif + // Allow the user to override this without any guarantee of success. #ifndef HWY_BROKEN_TARGETS @@ -246,7 +285,8 @@ (HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \ HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \ HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \ - HWY_BROKEN_SVE | HWY_BROKEN_PPC10) + HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_PPC10 | \ + HWY_BROKEN_PPC_32BIT | HWY_BROKEN_RVV) #endif // HWY_BROKEN_TARGETS @@ -316,13 +356,28 @@ #define HWY_BASELINE_PPC10 0 #endif +#if HWY_ARCH_S390X && defined(__VEC__) && defined(__ARCH__) && __ARCH__ >= 12 +#define HWY_BASELINE_Z14 HWY_Z14 +#else +#define HWY_BASELINE_Z14 0 +#endif + +#if HWY_BASELINE_Z14 && __ARCH__ >= 13 +#define HWY_BASELINE_Z15 HWY_Z15 +#else +#define HWY_BASELINE_Z15 0 +#endif + #define HWY_BASELINE_SVE2 0 #define HWY_BASELINE_SVE 0 #define HWY_BASELINE_NEON 0 #if HWY_ARCH_ARM -#if defined(__ARM_FEATURE_SVE2) +// Also check compiler version as done for HWY_ATTAINABLE_SVE2 because the +// static target (influenced here) must be one of the attainable targets. +#if defined(__ARM_FEATURE_SVE2) && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) #undef HWY_BASELINE_SVE2 // was 0, will be re-defined // If user specified -msve-vector-bits=128, they assert the vector length is // 128 bits and we should use the HWY_SVE2_128 (more efficient for some ops). @@ -337,7 +392,8 @@ #endif // __ARM_FEATURE_SVE_BITS #endif // __ARM_FEATURE_SVE2 -#if defined(__ARM_FEATURE_SVE) +#if defined(__ARM_FEATURE_SVE) && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) #undef HWY_BASELINE_SVE // was 0, will be re-defined // See above. If user-specified vector length matches our optimization, use it. #if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256 @@ -350,12 +406,17 @@ // GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both. 
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #undef HWY_BASELINE_NEON -#if defined(__ARM_FEATURE_AES) -#define HWY_BASELINE_NEON (HWY_NEON | HWY_NEON_WITHOUT_AES) +#if defined(__ARM_FEATURE_AES) && \ + defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + defined(__ARM_FEATURE_DOTPROD) && \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#define HWY_BASELINE_NEON HWY_ALL_NEON +#elif defined(__ARM_FEATURE_AES) +#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON) #else #define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES) -#endif -#endif +#endif // __ARM_FEATURE* +#endif // __ARM_NEON #endif // HWY_ARCH_ARM @@ -483,14 +544,16 @@ #define HWY_BASELINE_AVX3_ZEN4 0 #endif -#if HWY_BASELINE_AVX3_DL != 0 && defined(__AVX512FP16__) +#if HWY_BASELINE_AVX3_DL != 0 && defined(__AVX512BF16__) && \ + defined(__AVX512FP16__) #define HWY_BASELINE_AVX3_SPR HWY_AVX3_SPR #else #define HWY_BASELINE_AVX3_SPR 0 #endif // RVV requires intrinsics 0.11 or later, see #1156. -#if HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && __riscv_v_intrinsic >= 11000 +#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 #define HWY_BASELINE_RVV HWY_RVV #else #define HWY_BASELINE_RVV 0 @@ -498,13 +561,14 @@ // Allow the user to override this without any guarantee of success. #ifndef HWY_BASELINE_TARGETS -#define HWY_BASELINE_TARGETS \ - (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ - HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_SVE2 | \ - HWY_BASELINE_SVE | HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | \ - HWY_BASELINE_SSSE3 | HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | \ - HWY_BASELINE_AVX3 | HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | \ - HWY_BASELINE_AVX3_SPR | HWY_BASELINE_RVV) +#define HWY_BASELINE_TARGETS \ + (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ + HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \ + HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \ + HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \ + HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \ + HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX3_SPR | \ + HWY_BASELINE_RVV) #endif // HWY_BASELINE_TARGETS //------------------------------------------------------------------------------ @@ -534,17 +598,66 @@ #endif // Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE. -// Clang, GCC and MSVC allow runtime dispatch on x86. -#if HWY_ARCH_X86 -#define HWY_HAVE_RUNTIME_DISPATCH 1 -// On Arm/PPC, currently only GCC does, and we require Linux to detect CPU -// capabilities. -#elif (HWY_ARCH_ARM || HWY_ARCH_PPC) && HWY_COMPILER_GCC_ACTUAL && \ - HWY_OS_LINUX && !defined(TOOLCHAIN_MISS_SYS_AUXV_H) +#ifndef HWY_HAVE_AUXV // allow override +#ifdef TOOLCHAIN_MISS_SYS_AUXV_H +#define HWY_HAVE_AUXV 0 // CMake failed to find the header +// glibc 2.16 added auxv, but checking for that requires features.h, and we do +// not want to include system headers here. Instead check for the header +// directly, which has been supported at least since GCC 5.4 and Clang 3. 
+#elif defined(__has_include) // note: wrapper macro fails on Clang ~17 +// clang-format off +#if __has_include() +// clang-format on +#define HWY_HAVE_AUXV 1 // header present +#else +#define HWY_HAVE_AUXV 0 // header not present +#endif // __has_include +#else // compiler lacks __has_include +#define HWY_HAVE_AUXV 0 +#endif +#endif // HWY_HAVE_AUXV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_RVV // allow override +// The riscv_vector.h in Clang 16-18 requires compiler flags, and 19 still has +// some missing intrinsics, see +// https://github.com/llvm/llvm-project/issues/56592. GCC 13.3 also has an +// #error check, whereas 14.1 fails with "argument type 'vuint16m8_t' requires +// the V ISA extension": https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325. +#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900 && 0 +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_RVV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_APPLE // allow override +#if HWY_ARCH_ARM_A64 && HWY_OS_APPLE && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_APPLE + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_LINUX // allow override +#if (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X) && HWY_OS_LINUX && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) && HWY_HAVE_AUXV +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_LINUX + +// Allow opting out, and without a guarantee of success, opting-in. +#ifndef HWY_HAVE_RUNTIME_DISPATCH +// Clang, GCC and MSVC allow OS-independent runtime dispatch on x86. +#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \ + HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LINUX #define HWY_HAVE_RUNTIME_DISPATCH 1 #else #define HWY_HAVE_RUNTIME_DISPATCH 0 #endif +#endif // HWY_HAVE_RUNTIME_DISPATCH // AVX3_DL is not widely available yet. 
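
Runtime dispatch on Linux hinges on the auxiliary vector, hence the `HWY_HAVE_AUXV` check above. Purely as an illustration of what auxv-based capability probing looks like on AArch64 Linux (this is not Highway's detection code, and the fallback `HWCAP_SVE` value assumes the AArch64 Linux UAPI definition):

```cpp
// AArch64 Linux only: query CPU capabilities from the auxiliary vector.
#include <sys/auxv.h>  // getauxval, AT_HWCAP

#include <cstdio>

#ifndef HWCAP_SVE
#define HWCAP_SVE (1UL << 22)  // AArch64 Linux UAPI value (assumption if absent)
#endif

int main() {
  const unsigned long hwcap = getauxval(AT_HWCAP);
  std::printf("AT_HWCAP = 0x%lx, SVE %s\n", hwcap,
              (hwcap & HWCAP_SVE) ? "available" : "not reported");
  return 0;
}
```
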
To reduce code size and compile time, // only include it in the set of attainable targets (for dynamic dispatch) if @@ -556,22 +669,26 @@ #endif #if HWY_ARCH_ARM_A64 && HWY_HAVE_RUNTIME_DISPATCH -#define HWY_ATTAINABLE_NEON (HWY_NEON | HWY_NEON_WITHOUT_AES) +#define HWY_ATTAINABLE_NEON HWY_ALL_NEON #elif HWY_ARCH_ARM // static dispatch, or HWY_ARCH_ARM_V7 #define HWY_ATTAINABLE_NEON (HWY_BASELINE_NEON) #else #define HWY_ATTAINABLE_NEON 0 #endif -#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ - (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) #define HWY_ATTAINABLE_SVE (HWY_SVE | HWY_SVE_256) #else #define HWY_ATTAINABLE_SVE 0 #endif -#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ - (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) #define HWY_ATTAINABLE_SVE2 (HWY_SVE2 | HWY_SVE2_128) #else #define HWY_ATTAINABLE_SVE2 0 @@ -579,18 +696,51 @@ #if HWY_ARCH_PPC && defined(__ALTIVEC__) && \ (!HWY_COMPILER_CLANG || HWY_BASELINE_PPC8 != 0) + +#if (HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) +// On POWER with -m flags, we get compile errors (#1707) for targets older than +// the baseline specified via -m, so only generate the static target and better. +// Note that some Linux distros actually do set POWER9 as the baseline. +// This works by skipping case 3 below, so case 4 is reached. +#define HWY_SKIP_NON_BEST_BASELINE +#endif + #define HWY_ATTAINABLE_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + #else #define HWY_ATTAINABLE_PPC 0 #endif -// Attainable means enabled and the compiler allows intrinsics (even when not -// allowed to autovectorize). Used in 3 and 4. -#if HWY_ARCH_X86 -#define HWY_ATTAINABLE_TARGETS \ +#if HWY_ARCH_S390X && HWY_BASELINE_Z14 != 0 +#define HWY_ATTAINABLE_S390X (HWY_Z14 | HWY_Z15) +#else +#define HWY_ATTAINABLE_S390X 0 +#endif + +#if HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_RISCV HWY_RVV +#else +#define HWY_ATTAINABLE_RISCV HWY_BASELINE_RVV +#endif + +#ifndef HWY_ATTAINABLE_TARGETS_X86 // allow override +#if HWY_COMPILER_MSVC && defined(HWY_SLOW_MSVC) +// Fewer targets for faster builds. +#define HWY_ATTAINABLE_TARGETS_X86 \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_STATIC_TARGET | HWY_AVX2) +#else // !HWY_COMPILER_MSVC +#define HWY_ATTAINABLE_TARGETS_X86 \ HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | \ HWY_AVX2 | HWY_AVX3 | HWY_ATTAINABLE_AVX3_DL | HWY_AVX3_ZEN4 | \ HWY_AVX3_SPR) +#endif // !HWY_COMPILER_MSVC +#endif // HWY_ATTAINABLE_TARGETS_X86 + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to autovectorize). Used in 3 and 4. 
+#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS HWY_ATTAINABLE_TARGETS_X86 #elif HWY_ARCH_ARM #define HWY_ATTAINABLE_TARGETS \ HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_NEON | HWY_ATTAINABLE_SVE | \ @@ -598,6 +748,12 @@ #elif HWY_ARCH_PPC #define HWY_ATTAINABLE_TARGETS \ HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_PPC) +#elif HWY_ARCH_S390X +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_S390X) +#elif HWY_ARCH_RISCV +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_RISCV) #else #define HWY_ATTAINABLE_TARGETS (HWY_ENABLED_BASELINE) #endif // HWY_ARCH_* @@ -621,7 +777,8 @@ #define HWY_TARGETS HWY_STATIC_TARGET // 3) For tests: include all attainable targets (in particular: scalar) -#elif defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST) +#elif (defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST)) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) #define HWY_TARGETS HWY_ATTAINABLE_TARGETS // 4) Default: attainable WITHOUT non-best baseline. This reduces code size by diff --git a/r/src/vendor/highway/hwy/foreach_target.h b/r/src/vendor/highway/hwy/foreach_target.h index ca3e5a24..5219aee2 100644 --- a/r/src/vendor/highway/hwy/foreach_target.h +++ b/r/src/vendor/highway/hwy/foreach_target.h @@ -168,6 +168,17 @@ #endif #endif +#if (HWY_TARGETS & HWY_NEON_BF16) && (HWY_STATIC_TARGET != HWY_NEON_BF16) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON_BF16 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + #if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) #undef HWY_TARGET #define HWY_TARGET HWY_SVE @@ -271,7 +282,31 @@ #endif #endif -// ------------------------------ HWY_ARCH_RVV +// ------------------------------ HWY_ARCH_S390X + +#if (HWY_TARGETS & HWY_Z14) && (HWY_STATIC_TARGET != HWY_Z14) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z14 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_Z15) && (HWY_STATIC_TARGET != HWY_Z15) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z15 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_RISCV #if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV) #undef HWY_TARGET diff --git a/r/src/vendor/highway/hwy/highway.h b/r/src/vendor/highway/hwy/highway.h index 6d4a5d78..48359ea4 100644 --- a/r/src/vendor/highway/hwy/highway.h +++ b/r/src/vendor/highway/hwy/highway.h @@ -18,10 +18,17 @@ // IWYU pragma: begin_exports #include "hwy/base.h" #include "hwy/detect_compiler_arch.h" +#include "hwy/detect_targets.h" #include "hwy/highway_export.h" #include "hwy/targets.h" // IWYU pragma: end_exports +#if HWY_CXX_LANG < 201703L +#define HWY_DISPATCH_MAP 1 +#else +#define HWY_DISPATCH_MAP 0 +#endif + // This include guard is checked by foreach_target, so avoid the usual _H_ // suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included // after/outside this include guard. @@ -30,11 +37,6 @@ namespace hwy { -// API version (https://semver.org/); keep in sync with CMakeLists.txt. -#define HWY_MAJOR 1 -#define HWY_MINOR 0 -#define HWY_PATCH 7 - //------------------------------------------------------------------------------ // Shorthand for tags (defined in shared-inl.h) used to select overloads. 
// Note that ScalableTag is preferred over HWY_FULL, and CappedTag over @@ -84,6 +86,8 @@ namespace hwy { #define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_WITHOUT_AES::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_NEON #define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON_BF16 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_BF16::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_SVE #define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_SVE2 @@ -98,6 +102,10 @@ namespace hwy { #define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC9::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_PPC10 #define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC10::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_Z14 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z14::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_Z15 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z15::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_SSE2 #define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE2::FUNC_NAME #elif HWY_STATIC_TARGET == HWY_SSSE3 @@ -158,6 +166,12 @@ namespace hwy { #define HWY_CHOOSE_NEON(FUNC_NAME) nullptr #endif +#if HWY_TARGETS & HWY_NEON_BF16 +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) &N_NEON_BF16::FUNC_NAME +#else +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) nullptr +#endif + #if HWY_TARGETS & HWY_SVE #define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME #else @@ -200,6 +214,18 @@ namespace hwy { #define HWY_CHOOSE_PPC10(FUNC_NAME) nullptr #endif +#if HWY_TARGETS & HWY_Z14 +#define HWY_CHOOSE_Z14(FUNC_NAME) &N_Z14::FUNC_NAME +#else +#define HWY_CHOOSE_Z14(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_Z15 +#define HWY_CHOOSE_Z15(FUNC_NAME) &N_Z15::FUNC_NAME +#else +#define HWY_CHOOSE_Z15(FUNC_NAME) nullptr +#endif + #if HWY_TARGETS & HWY_SSE2 #define HWY_CHOOSE_SSE2(FUNC_NAME) &N_SSE2::FUNC_NAME #else @@ -252,41 +278,68 @@ namespace hwy { // apparently cannot be an array. Use a function pointer instead, which has the // disadvantage that we call the static (not best) target on the first call to // any HWY_DYNAMIC_DISPATCH. -#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915 +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) #define HWY_DISPATCH_WORKAROUND 1 #else #define HWY_DISPATCH_WORKAROUND 0 #endif +#if HWY_DISPATCH_MAP +struct AllExports { + template + static const FuncPtr*& GetRefToExportsPtr() { + static const FuncPtr* s_exports = nullptr; + return s_exports; + } +}; +#endif + // Provides a static member function which is what is called during the first // HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of -// this function are the first entry in the tables created by HWY_EXPORT. +// this function are the first entry in the tables created by HWY_EXPORT[_T]. template struct FunctionCache { public: - typedef RetType(FunctionType)(Args...); + typedef RetType(FuncType)(Args...); + using FuncPtr = FuncType*; -#if HWY_DISPATCH_WORKAROUND - template - static RetType ChooseAndCall(Args... args) { - ChosenTarget& chosen_target = GetChosenTarget(); - chosen_target.Update(SupportedTargets()); - return (*func)(args...); - } -#else // A template function that when instantiated has the same signature as the // function being called. This function initializes the bit array of targets // supported by the current CPU and then calls the appropriate entry within // the HWY_EXPORT table. 
Subsequent calls via HWY_DYNAMIC_DISPATCH to any // exported functions, even those defined by different translation units, // will dispatch directly to the best available target. - template +#if HWY_DISPATCH_MAP + template static RetType ChooseAndCall(Args... args) { ChosenTarget& chosen_target = GetChosenTarget(); chosen_target.Update(SupportedTargets()); + + const FuncPtr* table = AllExports::template GetRefToExportsPtr< + FuncPtr, RemoveCvRef, kHash>(); + HWY_ASSERT(table); + + return (table[chosen_target.GetIndex()])(args...); + } + +#if !HWY_DISPATCH_WORKAROUND + template + static RetType TableChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); return (table[chosen_target.GetIndex()])(args...); } -#endif // HWY_DISPATCH_WORKAROUND +#endif // !HWY_DISPATCH_WORKAROUND + +#else // !HWY_DISPATCH_MAP: zero-overhead, but requires C++17 + template + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // HWY_DISPATCH_MAP }; // Used to deduce the template parameters RetType and Args from a function. @@ -299,9 +352,7 @@ FunctionCache DeduceFunctionCache(RetType (*)(Args...)) { HWY_CONCAT(FUNC_NAME, HighwayDispatchTable) // HWY_EXPORT(FUNC_NAME); expands to a static array that is used by -// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. This -// static array must be defined at the same namespace level as the function -// it is exporting. +// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. // After being exported, it can be called from other parts of the same source // file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper // like in the following example: @@ -326,59 +377,181 @@ FunctionCache DeduceFunctionCache(RetType (*)(Args...)) { // } // } // namespace skeleton // +// For templated code with a single type parameter, instead use HWY_EXPORT_T and +// its HWY_DYNAMIC_DISPATCH_T counterpart: +// +// template +// void MyFunctionCaller(T ...) { +// // First argument to both HWY_EXPORT_T and HWY_DYNAMIC_DISPATCH_T is an +// // arbitrary table name; you must provide the same name for each call. +// // It is fine to have multiple HWY_EXPORT_T in a function, but a 64-bit +// // FNV hash collision among *any* table names will trigger HWY_ABORT. +// HWY_EXPORT_T(Table1, MyFunction) +// HWY_DYNAMIC_DISPATCH_T(Table1)(a, b, c); +// } +// +// Note that HWY_EXPORT_T must be invoked inside a template (in the above +// example: `MyFunctionCaller`), so that a separate table will be created for +// each template instantiation. For convenience, we also provide a macro that +// combines both steps and avoids the need to pick a table name: +// +// template +// void MyFunctionCaller(T ...) { +// // Table name is automatically chosen. Note that this variant must be +// // called in statement context; it is not a valid expression. +// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(MyFunction)(a, b, c); +// } +// Simplified version for IDE or the dynamic dispatch case with only one target. #if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) -// Simplified version for IDE or the dynamic dispatch case with only one target. -// This case still uses a table, although of a single element, to provide the -// same compile error conditions as with the dynamic dispatch case when multiple -// targets are being compiled. 
-#define HWY_EXPORT(FUNC_NAME) \ +// We use a table to provide the same compile error conditions as with the +// non-simplified case, but the table only has a single entry. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \ - HWY_DISPATCH_TABLE(FUNC_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} -#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) + HWY_DISPATCH_TABLE(TABLE_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} + +// Use the table, not just STATIC_DISPATCH as in DYNAMIC_DISPATCH, because +// TABLE_NAME might not match the function name. +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) (HWY_DISPATCH_TABLE(TABLE_NAME)[0]) +#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) \ + (*(HWY_DYNAMIC_POINTER_T(TABLE_NAME))) + +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) #define HWY_DYNAMIC_POINTER(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) -#else +#else // not simplified: full table + +// Pre-C++17 workaround: non-type template arguments must have linkage, which +// means we cannot pass &table as a template argument to ChooseAndCall. +// ChooseAndCall must find a way to access the table in order to dispatch to the +// chosen target: +// 0) Skipping this by dispatching to the static target would be surprising to +// users and may have serious performance implications. +// 1) An extra function parameter would be unacceptable because it changes the +// user-visible function signature. +// 2) Declaring a table, then defining a pointer to it would work, but requires +// an additional DECLARE step outside the function so that the pointer has +// linkage, which breaks existing code. +// 3) We instead associate the function with the table using an instance of an +// unnamed struct and the hash of the table name as the key. Because +// ChooseAndCall has the type information, it can then cast to the function +// pointer type. However, we cannot simply pass the name as a template +// argument to ChooseAndCall because this requires char*, which hits the same +// linkage problem. We instead hash the table name, which assumes the +// function names do not have collisions. +#if HWY_DISPATCH_MAP + +static constexpr uint64_t FNV(const char* name) { + return *name ? static_cast(static_cast(*name)) ^ + (0x100000001b3ULL * FNV(name + 1)) + : 0xcbf29ce484222325ULL; +} -// Simplified version for MSVC 2017: function pointer instead of table. -#if HWY_DISPATCH_WORKAROUND +template +struct AddExport { + template + AddExport(ExportsKey /*exports_key*/, const char* table_name, + const FuncPtr* table) { + using FuncCache = decltype(DeduceFunctionCache(hwy::DeclVal())); + static_assert( + hwy::IsSame, typename FuncCache::FuncPtr>(), + "FuncPtr should be same type as FuncCache::FuncPtr"); + + const FuncPtr*& exports_ptr = AllExports::template GetRefToExportsPtr< + RemoveCvRef, RemoveCvRef, kHash>(); + if (exports_ptr && exports_ptr != table) { + HWY_ABORT("Hash collision for %s, rename the function\n", table_name); + } else { + exports_ptr = table; + } + } +}; +// Dynamic dispatch: defines table of function pointers. This must be invoked +// from inside the function template that calls the template we are exporting. +// TABLE_NAME must match the one passed to HWY_DYNAMIC_DISPATCH_T. This +// argument allows multiple exports within one function. 
+#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + static const struct { \ + } HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey) = {}; \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + TABLE_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + }; \ + HWY_MAYBE_UNUSED static hwy::AddExport HWY_CONCAT( \ + HighwayAddTable, __LINE__)( \ + HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey), #TABLE_NAME, \ + HWY_DISPATCH_TABLE(TABLE_NAME)) + +// For non-template functions. Not necessarily invoked within a function, hence +// we derive the string and variable names from FUNC_NAME, not HWY_FUNCTION. +#if HWY_DISPATCH_WORKAROUND +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) +#else #define HWY_EXPORT(FUNC_NAME) \ static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ - FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + FUNC_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ /* The first entry in the table initializes the global cache and \ - * calls the function from HWY_STATIC_TARGET. */ \ - &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ - FUNC_NAME)))::ChooseAndCall<&HWY_STATIC_DISPATCH(FUNC_NAME)>, \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template TableChooseAndCall, \ HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ HWY_CHOOSE_FALLBACK(FUNC_NAME), \ } +#endif // HWY_DISPATCH_WORKAROUND -#else +#else // !HWY_DISPATCH_MAP -// Dynamic dispatch case with one entry per dynamic target plus the fallback -// target and the initialization wrapper. -#define HWY_EXPORT(FUNC_NAME) \ +// Zero-overhead, but requires C++17 for non-type template arguments without +// linkage, because HWY_EXPORT_T tables are local static variables. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ - FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + TABLE_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ /* The first entry in the table initializes the global cache and \ * calls the appropriate function. */ \ - &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ - FUNC_NAME)))::ChooseAndCall, \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall, \ HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ HWY_CHOOSE_FALLBACK(FUNC_NAME), \ } -#endif // HWY_DISPATCH_WORKAROUND +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) + +#endif // HWY_DISPATCH_MAP + +// HWY_DISPATCH_MAP only affects how tables are created, not their usage. -#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \ - (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])) +// Evaluates to the function pointer for the chosen target. #define HWY_DYNAMIC_POINTER(FUNC_NAME) \ (HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()]) +// Calls the function pointer for the chosen target. +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) (*(HWY_DYNAMIC_POINTER(FUNC_NAME))) + +// Same as DISPATCH, but provide a different arg name to clarify usage. 
+#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) HWY_DYNAMIC_DISPATCH(TABLE_NAME) +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) HWY_DYNAMIC_POINTER(TABLE_NAME) + #endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) +// Returns the name of an anonymous dispatch table that is only shared with +// macro invocations coming from the same source line. +#define HWY_DISPATCH_TABLE_T() HWY_CONCAT(HighwayDispatchTableT, __LINE__) + +// For templated code, combines export and dispatch using an anonymous table. +#define HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC_NAME) \ + HWY_EXPORT_T(HWY_DISPATCH_TABLE_T(), FUNC_NAME); \ + HWY_DYNAMIC_DISPATCH_T(HWY_DISPATCH_TABLE_T()) + // DEPRECATED names; please use HWY_HAVE_* instead. #define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 #define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 @@ -408,13 +581,12 @@ FunctionCache DeduceFunctionCache(RetType (*)(Args...)) { #elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \ HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR #include "hwy/ops/x86_512-inl.h" -#elif HWY_TARGET == HWY_PPC8 || HWY_TARGET == HWY_PPC9 || \ - HWY_TARGET == HWY_PPC10 +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 || \ + (HWY_TARGET & HWY_ALL_PPC) #include "hwy/ops/ppc_vsx-inl.h" -#elif HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_NEON_WITHOUT_AES +#elif HWY_TARGET & HWY_ALL_NEON #include "hwy/ops/arm_neon-inl.h" -#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || \ - HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#elif HWY_TARGET & HWY_ALL_SVE #include "hwy/ops/arm_sve-inl.h" #elif HWY_TARGET == HWY_WASM_EMU256 #include "hwy/ops/wasm_256-inl.h" diff --git a/r/src/vendor/highway/hwy/nanobenchmark.cc b/r/src/vendor/highway/hwy/nanobenchmark.cc index ea5549f3..0dec0bc4 100644 --- a/r/src/vendor/highway/hwy/nanobenchmark.cc +++ b/r/src/vendor/highway/hwy/nanobenchmark.cc @@ -24,6 +24,7 @@ #include #include +#include "hwy/base.h" #include "hwy/robust_statistics.h" #include "hwy/timer-inl.h" #include "hwy/timer.h" @@ -76,7 +77,9 @@ timer::Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, // For "few" (depends also on the variance) samples, Median is safer. est = robust_statistics::Median(samples.data(), samples.size()); } - NANOBENCHMARK_CHECK(est != 0); + if (est == 0) { + HWY_WARN("estimated duration is 0\n"); + } // Median absolute deviation (mad) is a robust measure of 'variability'. const timer::Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( @@ -194,9 +197,9 @@ void FillSubset(const InputVec& full, const FuncInput input_to_skip, (*subset)[idx_subset++] = next; } } - NANOBENCHMARK_CHECK(idx_subset == subset->size()); - NANOBENCHMARK_CHECK(idx_omit == omit.size()); - NANOBENCHMARK_CHECK(occurrence == count - 1); + HWY_DASSERT(idx_subset == subset->size()); + HWY_DASSERT(idx_omit == omit.size()); + HWY_DASSERT(occurrence == count - 1); } // Returns total ticks elapsed for all inputs. 
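
One more hedged sketch for the pointer form: `HWY_DYNAMIC_POINTER`, documented above as evaluating to the function pointer for the chosen target, can be cached around a hot loop. `MulLoop` stands for any function previously exported with `HWY_EXPORT(MulLoop)` in the same translation unit (see the earlier skeleton sketch), and the explicit `Update` call reflects the lazy-initialization behaviour described above rather than a required step:

```cpp
// Continues the HWY_ONCE section of the skeleton sketch above.
namespace skeleton {

void CallMulLoopMany(const float* a, const float* b, float* out, size_t n,
                     size_t repeats) {
  // Initialize the chosen-target index; before the first dispatch, the table
  // entry is still the lazy-init wrapper (which also works, but re-checks the
  // chosen target on every call).
  hwy::GetChosenTarget().Update(hwy::SupportedTargets());
  const auto fn = HWY_DYNAMIC_POINTER(MulLoop);  // plain function pointer
  for (size_t r = 0; r < repeats; ++r) {
    (*fn)(a, b, out, n);
  }
}

}  // namespace skeleton
```
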
@@ -239,12 +242,11 @@ HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; } HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, const FuncInput* inputs, const size_t num_inputs, Result* results, const Params& p) { - NANOBENCHMARK_CHECK(num_inputs != 0); + HWY_DASSERT(num_inputs != 0); char cpu100[100]; if (!platform::HaveTimerStop(cpu100)) { - fprintf(stderr, "CPU '%s' does not support RDTSCP, skipping benchmark.\n", - cpu100); + HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100); return 0; } @@ -262,8 +264,8 @@ HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, const timer::Ticks overhead = Overhead(arg, &full, p); const timer::Ticks overhead_skip = Overhead(arg, &subset, p); if (overhead < overhead_skip) { - fprintf(stderr, "Measurement failed: overhead %d < %d\n", - static_cast(overhead), static_cast(overhead_skip)); + HWY_WARN("Measurement failed: overhead %d < %d\n", + static_cast(overhead), static_cast(overhead_skip)); return 0; } @@ -282,8 +284,8 @@ HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, TotalDuration(func, arg, &subset, p, &max_rel_mad); if (total < total_skip) { - fprintf(stderr, "Measurement failed: total %f < %f\n", - static_cast(total), static_cast(total_skip)); + HWY_WARN("Measurement failed: total %f < %f\n", + static_cast(total), static_cast(total_skip)); return 0; } diff --git a/r/src/vendor/highway/hwy/nanobenchmark.h b/r/src/vendor/highway/hwy/nanobenchmark.h index 46bfc4b0..eefe6fb7 100644 --- a/r/src/vendor/highway/hwy/nanobenchmark.h +++ b/r/src/vendor/highway/hwy/nanobenchmark.h @@ -49,25 +49,7 @@ #include #include "hwy/highway_export.h" -#include "hwy/timer.h" - -// Enables sanity checks that verify correct operation at the cost of -// longer benchmark runs. -#ifndef NANOBENCHMARK_ENABLE_CHECKS -#define NANOBENCHMARK_ENABLE_CHECKS 0 -#endif - -#define NANOBENCHMARK_CHECK_ALWAYS(condition) \ - while (!(condition)) { \ - fprintf(stderr, "Nanobenchmark check failed at line %d\n", __LINE__); \ - abort(); \ - } - -#if NANOBENCHMARK_ENABLE_CHECKS -#define NANOBENCHMARK_CHECK(condition) NANOBENCHMARK_CHECK_ALWAYS(condition) -#else -#define NANOBENCHMARK_CHECK(condition) -#endif +#include "hwy/timer.h" // IWYU pragma: export namespace hwy { diff --git a/r/src/vendor/highway/hwy/ops/arm_neon-inl.h b/r/src/vendor/highway/hwy/ops/arm_neon-inl.h index d9a72c1c..37294205 100644 --- a/r/src/vendor/highway/hwy/ops/arm_neon-inl.h +++ b/r/src/vendor/highway/hwy/ops/arm_neon-inl.h @@ -1,5 +1,7 @@ // Copyright 2019 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates // SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,16 +23,12 @@ #include "hwy/ops/shared-inl.h" -HWY_BEFORE_NAMESPACE(); - -// Must come after HWY_BEFORE_NAMESPACE so that the intrinsics are compiled with -// the same target attribute as our code, see #834. HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") #include // NOLINT(build/include_order) HWY_DIAGNOSTICS(pop) -// Must come after arm_neon.h. 
+HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { @@ -143,12 +141,29 @@ namespace detail { // for code folding and Raw128 HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) -#ifdef __ARM_FEATURE_BF16_VECTOR_ARITHMETIC +// Clang 17 crashes with bf16, see github.com/llvm/llvm-project/issues/64179. +#undef HWY_NEON_HAVE_BFLOAT16 +#if HWY_HAVE_SCALAR_BF16_TYPE && \ + ((HWY_TARGET == HWY_NEON_BF16 && \ + (!HWY_COMPILER_CLANG || HWY_COMPILER_CLANG >= 1800)) || \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)) #define HWY_NEON_HAVE_BFLOAT16 1 #else #define HWY_NEON_HAVE_BFLOAT16 0 #endif +// HWY_NEON_HAVE_F32_TO_BF16C is defined if NEON vcvt_bf16_f32 and +// vbfdot_f32 are available, even if the __bf16 type is disabled due to +// GCC/Clang bugs. +#undef HWY_NEON_HAVE_F32_TO_BF16C +#if HWY_NEON_HAVE_BFLOAT16 || HWY_TARGET == HWY_NEON_BF16 || \ + (defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 1100)) +#define HWY_NEON_HAVE_F32_TO_BF16C 1 +#else +#define HWY_NEON_HAVE_F32_TO_BF16C 0 +#endif + // bfloat16_t #if HWY_NEON_HAVE_BFLOAT16 #define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ @@ -160,7 +175,7 @@ namespace detail { // for code folding and Raw128 #define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) #endif -// Used for conversion instructions if HWY_NEON_HAVE_FLOAT16C. +// Used for conversion instructions if HWY_NEON_HAVE_F16C. #define HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(name, prefix, infix, \ args) \ HWY_NEON_DEF_FUNCTION(float16, 8, name, prefix##q, infix, f16, args) \ @@ -176,6 +191,33 @@ namespace detail { // for code folding and Raw128 #define HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) #endif +// Enable generic functions for whichever of (f16, bf16) are not supported. 
+#if !HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) +#elif !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_F16_D(D) +#elif HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#elif HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_NEON_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_NEON_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr +#else +#error "Logic error, handled all four cases" +#endif + // float #define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ @@ -397,39 +439,6 @@ struct Tuple2 { int64x1x2_t raw; }; -template <> -struct Tuple2 { -#if HWY_NEON_HAVE_FLOAT16C - float16x8x2_t raw; -#else - uint16x8x2_t raw; -#endif -}; -template -struct Tuple2 { -#if HWY_NEON_HAVE_FLOAT16C - float16x4x2_t raw; -#else - uint16x4x2_t raw; -#endif -}; -template <> -struct Tuple2 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x8x2_t raw; -#else - uint16x8x2_t raw; -#endif -}; -template -struct Tuple2 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x4x2_t raw; -#else - uint16x4x2_t raw; -#endif -}; - template <> struct Tuple2 { float32x4x2_t raw; @@ -514,39 +523,6 @@ struct Tuple3 { int64x1x3_t raw; }; -template <> -struct Tuple3 { -#if HWY_NEON_HAVE_FLOAT16C - float16x8x3_t raw; -#else - uint16x8x3_t raw; -#endif -}; -template -struct Tuple3 { -#if HWY_NEON_HAVE_FLOAT16C - float16x4x3_t raw; -#else - uint16x4x3_t raw; -#endif -}; -template <> -struct Tuple3 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x8x3_t raw; -#else - uint16x8x3_t raw; -#endif -}; -template -struct Tuple3 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x4x3_t raw; -#else - uint16x4x3_t raw; -#endif -}; - template <> struct Tuple3 { float32x4x3_t raw; @@ -631,39 +607,6 @@ struct Tuple4 { int64x1x4_t raw; }; -template <> -struct Tuple4 { -#if HWY_NEON_HAVE_FLOAT16C - float16x8x4_t raw; -#else - uint16x8x4_t raw; -#endif -}; -template -struct Tuple4 { -#if HWY_NEON_HAVE_FLOAT16C - float16x4x4_t raw; -#else - uint16x4x4_t raw; -#endif -}; -template <> -struct Tuple4 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x8x4_t raw; -#else - uint16x8x4_t raw; -#endif -}; -template -struct Tuple4 { -#if HWY_NEON_HAVE_BFLOAT16 - bfloat16x4x4_t raw; -#else - uint16x4x4_t raw; -#endif -}; - template <> struct Tuple4 { float32x4x4_t raw; @@ -686,201 +629,199 @@ struct Tuple4 { template struct Raw128; -// 128 template <> struct Raw128 { using type = uint8x16_t; }; +template +struct Raw128 { + using type = uint8x8_t; +}; template <> struct Raw128 { using type = uint16x8_t; }; +template +struct Raw128 { + using type = uint16x4_t; +}; template <> struct Raw128 { using type = uint32x4_t; }; +template +struct Raw128 { + using type = uint32x2_t; +}; 
template <> struct Raw128 { using type = uint64x2_t; }; +template <> +struct Raw128 { + using type = uint64x1_t; +}; template <> struct Raw128 { using type = int8x16_t; }; +template +struct Raw128 { + using type = int8x8_t; +}; template <> struct Raw128 { using type = int16x8_t; }; +template +struct Raw128 { + using type = int16x4_t; +}; template <> struct Raw128 { using type = int32x4_t; }; +template +struct Raw128 { + using type = int32x2_t; +}; template <> struct Raw128 { using type = int64x2_t; }; - -template <> -struct Raw128 { -#if HWY_NEON_HAVE_FLOAT16C - using type = float16x8_t; -#else - using type = uint16x8_t; -#endif -}; - template <> -struct Raw128 { -#if HWY_NEON_HAVE_BFLOAT16 - using type = bfloat16x8_t; -#else - using type = uint16x8_t; -#endif +struct Raw128 { + using type = int64x1_t; }; template <> struct Raw128 { using type = float32x4_t; }; +template +struct Raw128 { + using type = float32x2_t; +}; #if HWY_HAVE_FLOAT64 template <> struct Raw128 { using type = float64x2_t; }; -#endif // HWY_HAVE_FLOAT64 - -// 64 template <> -struct Raw128 { - using type = uint8x8_t; -}; - -template <> -struct Raw128 { - using type = uint16x4_t; +struct Raw128 { + using type = float64x1_t; }; +#endif // HWY_HAVE_FLOAT64 -template <> -struct Raw128 { - using type = uint32x2_t; -}; +#if HWY_NEON_HAVE_F16C template <> -struct Raw128 { - using type = uint64x1_t; -}; - -template <> -struct Raw128 { - using type = int8x8_t; +struct Tuple2 { + float16x8x2_t raw; }; - -template <> -struct Raw128 { - using type = int16x4_t; +template +struct Tuple2 { + float16x4x2_t raw; }; template <> -struct Raw128 { - using type = int32x2_t; +struct Tuple3 { + float16x8x3_t raw; }; - -template <> -struct Raw128 { - using type = int64x1_t; +template +struct Tuple3 { + float16x4x3_t raw; }; template <> -struct Raw128 { -#if HWY_NEON_HAVE_FLOAT16C - using type = float16x4_t; -#else - using type = uint16x4_t; -#endif +struct Tuple4 { + float16x8x4_t raw; }; - -template <> -struct Raw128 { -#if HWY_NEON_HAVE_BFLOAT16 - using type = bfloat16x4_t; -#else - using type = uint16x4_t; -#endif +template +struct Tuple4 { + float16x4x4_t raw; }; template <> -struct Raw128 { - using type = float32x2_t; +struct Raw128 { + using type = float16x8_t; }; - -#if HWY_HAVE_FLOAT64 -template <> -struct Raw128 { - using type = float64x1_t; +template +struct Raw128 { + using type = float16x4_t; }; -#endif // HWY_HAVE_FLOAT64 - -// 32 (same as 64) -template <> -struct Raw128 : public Raw128 {}; - -template <> -struct Raw128 : public Raw128 {}; - -template <> -struct Raw128 : public Raw128 {}; -template <> -struct Raw128 : public Raw128 {}; - -template <> -struct Raw128 : public Raw128 {}; +#else // !HWY_NEON_HAVE_F16C -template <> -struct Raw128 : public Raw128 {}; - -template <> -struct Raw128 : public Raw128 {}; - -template <> -struct Raw128 : public Raw128 {}; +template +struct Tuple2 : public Tuple2 {}; +template +struct Tuple3 : public Tuple3 {}; +template +struct Tuple4 : public Tuple4 {}; +template +struct Raw128 : public Raw128 {}; -template <> -struct Raw128 : public Raw128 {}; +#endif // HWY_NEON_HAVE_F16C -// 16 (same as 64) -template <> -struct Raw128 : public Raw128 {}; +#if HWY_NEON_HAVE_BFLOAT16 template <> -struct Raw128 : public Raw128 {}; +struct Tuple2 { + bfloat16x8x2_t raw; +}; +template +struct Tuple2 { + bfloat16x4x2_t raw; +}; template <> -struct Raw128 : public Raw128 {}; +struct Tuple3 { + bfloat16x8x3_t raw; +}; +template +struct Tuple3 { + bfloat16x4x3_t raw; +}; template <> -struct Raw128 : public Raw128 {}; 
+struct Tuple4 { + bfloat16x8x4_t raw; +}; +template +struct Tuple4 { + bfloat16x4x4_t raw; +}; template <> -struct Raw128 : public Raw128 {}; +struct Raw128 { + using type = bfloat16x8_t; +}; +template +struct Raw128 { + using type = bfloat16x4_t; +}; -template <> -struct Raw128 : public Raw128 {}; +#else // !HWY_NEON_HAVE_BFLOAT16 -// 8 (same as 64) -template <> -struct Raw128 : public Raw128 {}; +template +struct Tuple2 : public Tuple2 {}; +template +struct Tuple3 : public Tuple3 {}; +template +struct Tuple4 : public Tuple4 {}; +template +struct Raw128 : public Raw128 {}; -template <> -struct Raw128 : public Raw128 {}; +#endif // HWY_NEON_HAVE_BFLOAT16 } // namespace detail @@ -910,6 +851,9 @@ class Vec128 { HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } @@ -935,10 +879,10 @@ using Vec16 = Vec128; // FF..FF or 0. template class Mask128 { + public: // Arm C Language Extensions return and expect unsigned type. using Raw = typename detail::Raw128, N>::type; - public: using PrivateT = T; // only for DFromM static constexpr size_t kPrivateN = N; // only for DFromM @@ -978,26 +922,22 @@ namespace detail { #define HWY_NEON_BUILD_ARG_HWY_SET t HWY_NEON_DEF_FUNCTION_ALL_TYPES(NativeSet, vdup, _n_, HWY_SET) -HWY_NEON_DEF_FUNCTION_BFLOAT_16(NativeSet, vdup, _n_, HWY_SET) -#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_FLOAT16C +#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_F16C HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(NativeSet, vdup, _n_, HWY_SET) #endif +HWY_NEON_DEF_FUNCTION_BFLOAT_16(NativeSet, vdup, _n_, HWY_SET) + +template +HWY_API Vec128, MaxLanes(D())> NativeSet(D d, TFromD t) { + const uint16_t tu = BitCastScalar(t); + return Vec128, d.MaxLanes()>(Set(RebindToUnsigned(), tu).raw); +} #undef HWY_NEON_BUILD_TPL_HWY_SET #undef HWY_NEON_BUILD_RET_HWY_SET #undef HWY_NEON_BUILD_PARAM_HWY_SET #undef HWY_NEON_BUILD_ARG_HWY_SET -#if !HWY_NEON_HAVE_BFLOAT16 -// BF16: return u16. -template -HWY_API Vec128 NativeSet(D d, bfloat16_t t) { - uint16_t tu; - CopyBytes(&t, &tu); - return Vec128(Set(RebindToUnsigned(), tu).raw); -} -#endif // !HWY_NEON_HAVE_BFLOAT16 - } // namespace detail // Full vector. Cannot yet use VFromD because that is defined in terms of Set. 
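The hunk above stores `bfloat16_t`/`float16_t` lanes as raw `uint16_t` bit patterns whenever the target lacks native 16-bit float vectors, and the emulated `NativeSet` overload bit-casts the scalar to `uint16_t` before splatting it. A minimal standalone sketch of that bit-pattern view follows; it is plain C++ rather than Highway/NEON code, the helper name is hypothetical, and it assumes the usual bfloat16 encoding (the upper 16 bits of an IEEE-754 binary32):

```cpp
// Standalone sketch, not part of the vendored patch: emulate a bfloat16 "Set"
// by splatting the 16-bit bit pattern into uint16_t storage.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Assumed encoding: bfloat16 is the upper 16 bits of an IEEE-754 binary32.
static uint16_t BF16BitsFromFloat(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));      // bit-cast without aliasing UB
  return static_cast<uint16_t>(bits >> 16);  // drop the low mantissa bits
}

int main() {
  uint16_t lanes[8];                             // stand-in for a u16 vector
  const uint16_t one = BF16BitsFromFloat(1.0f);  // 0x3F80
  for (int i = 0; i < 8; ++i) lanes[i] = one;    // "Set": splat the pattern
  std::printf("bf16(1.0f) bits = 0x%04X\n", lanes[0]);
  return 0;
}
```

Because the emulated lanes are just `uint16_t` payloads, the fallback `Raw128`/`Tuple` specializations in the hunk above can inherit directly from their `uint16_t` counterparts.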
@@ -1033,165 +973,323 @@ HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") template HWY_API VFromD Undefined(D /*tag*/) { +#if HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(D()).raw)}; +#else VFromD v; return v; +#endif } HWY_DIAGNOSTICS(pop) +#if !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL namespace detail { -template -HWY_INLINE VFromD Iota0(D d) { - const RebindToUnsigned du; +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail +#endif // !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD /*t8*/, TFromD /*t9*/, + TFromD /*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(8))); - constexpr GccU8RawVectType kU8Iota0 = {0, 1, 2, 3, 4, 5, 6, 7}; - const VFromD vu8_iota0(reinterpret_cast(kU8Iota0)); + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(8) static constexpr uint8_t kU8Iota0[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - const VFromD vu8_iota0( - Load(Full64>(), kU8Iota0).raw); + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar(detail::Vec64ValsWrapper>{ + {t0, t1, t2, t3, t4, t5, t6, t7}}))); #endif - return BitCast(d, vu8_iota0); } -template -HWY_INLINE VFromD Iota0(D d) { - const RebindToUnsigned du; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(16))); - constexpr GccU8RawVectType kU8Iota0 = {0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15}; - const VFromD vu8_iota0(reinterpret_cast(kU8Iota0)); + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr uint8_t kU8Iota0[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const auto vu8_iota0 = Load(du, kU8Iota0); + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}))); #endif - return BitCast(d, vu8_iota0); } -template -HWY_INLINE VFromD Iota0(D d) { - using T = TFromD; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef detail::NativeLaneType GccRawVectType - __attribute__((__vector_size__(8))); - constexpr GccRawVectType kIota0 = {T{0}, T{1}, T{2}, static_cast(3)}; - return VFromD(reinterpret_cast::Raw>(kIota0)); + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI32RawVectType raw = {static_cast(t0), + static_cast(t1)}; + 
return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(8) static constexpr T kIota0[4] = {T{0}, T{1}, T{2}, - static_cast(3)}; - return Load(d, kIota0); + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); #endif } -template -HWY_INLINE VFromD Iota0(D d) { - using T = TFromD; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef detail::NativeLaneType GccRawVectType - __attribute__((__vector_size__(16))); - constexpr GccRawVectType kIota0 = {T{0}, T{1}, T{2}, static_cast(3), - T{4}, T{5}, T{6}, static_cast(7)}; - return VFromD(reinterpret_cast::Raw>(kIota0)); + typedef float GccF32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF32RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr T kU16Iota0[8] = { - T{0}, T{1}, T{2}, static_cast(3), T{4}, T{5}, T{6}, static_cast(7)}; - return Load(d, kIota0); + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); #endif } -template -HWY_INLINE VFromD Iota0(D d) { - const RebindToUnsigned du; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(8))); - constexpr GccU32RawVectType kU32Iota0 = {0, 1}; - const VFromD vu32_iota0( - reinterpret_cast(kU32Iota0)); + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), + static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(8) static constexpr uint32_t kU32Iota0[2] = {0, 1}; - const VFromD vu32_iota0{ - Load(Full64>(), kU32Iota0).raw}; + const Half dh; + return Combine(d, + Dup128VecFromValues(dh, t8, t9, t10, t11, t12, t13, t14, t15, + t8, t9, t10, t11, t12, t13, t14, t15), + Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, + t2, t3, t4, t5, t6, t7)); #endif - return BitCast(d, vu32_iota0); } -template -HWY_INLINE VFromD Iota0(D d) { - const RebindToUnsigned du; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); - constexpr GccU32RawVectType kU32Iota0 = {0, 1, 2, 3}; - const VFromD vu32_iota0( - reinterpret_cast(kU32Iota0)); + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr uint32_t kU32Iota0[4] = {0, 1, 2, 3}; - const auto vu32_iota0 = Load(du, 
kU32Iota0); + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t4, t5, t6, t7, t4, t5, t6, t7), + Dup128VecFromValues(dh, t0, t1, t2, t3, t0, t1, t2, t3)); #endif - return BitCast(d, vu32_iota0); } -template -HWY_INLINE VFromD Iota0(D d) { +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef float GccF32RawVectType __attribute__((__vector_size__(8))); - constexpr GccF32RawVectType kF32Iota0 = {0.0f, 1.0f}; - return VFromD(reinterpret_cast(kF32Iota0)); + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI32RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(8) static constexpr float kF32Iota0[2] = {0.0f, 1.0f}; - return VFromD{ - Load(Full64>(), kF32Iota0).raw}; + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); #endif } -template -HWY_INLINE VFromD Iota0(D d) { +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL typedef float GccF32RawVectType __attribute__((__vector_size__(16))); - constexpr GccF32RawVectType kF32Iota0 = {0.0f, 1.0f, 2.0f, 3.0f}; - return VFromD(reinterpret_cast(kF32Iota0)); + (void)d; + const GccF32RawVectType raw = {t0, t1, t2, t3}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr float kF32Iota0[4] = {0.0f, 1.0f, 2.0f, 3.0f}; - return Load(d, kF32Iota0); + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); #endif } -template -HWY_INLINE VFromD Iota0(D d) { - return Zero(d); -} - -template -HWY_INLINE VFromD Iota0(D d) { - const RebindToUnsigned du; +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL - typedef uint64_t GccU64RawVectType __attribute__((__vector_size__(16))); - constexpr GccU64RawVectType kU64Iota0 = {0, 1}; - const VFromD vu64_iota0( - reinterpret_cast(kU64Iota0)); + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI64RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr uint64_t kU64Iota0[4] = {0, 1}; - const auto vu64_iota0 = Load(du, kU64Iota0); + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); #endif - return BitCast(d, vu64_iota0); } #if HWY_HAVE_FLOAT64 -template -HWY_INLINE VFromD Iota0(D d) { +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { #if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL typedef double GccF64RawVectType __attribute__((__vector_size__(16))); - constexpr GccF64RawVectType kF64Iota0 = {0.0, 1.0}; - return VFromD(reinterpret_cast(kF64Iota0)); + (void)d; + const GccF64RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); #else - alignas(16) static constexpr double kF64Iota0[4] = {0.0, 1.0}; - return Load(d, kF64Iota0); + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); +#endif +} #endif + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, 
BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF16RawVectType raw = { + static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF16RawVectType raw = { + static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3), static_cast<__fp16>(t4), static_cast<__fp16>(t5), + static_cast<__fp16>(t6), static_cast<__fp16>(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +#else +// Generic for all vector lengths if MSVC or !HWY_NEON_HAVE_F16C +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues( + d, TFromD{0}, TFromD{1}, TFromD{2}, TFromD{3}, TFromD{4}, + TFromD{5}, TFromD{6}, TFromD{7}, TFromD{8}, TFromD{9}, + TFromD{10}, TFromD{11}, TFromD{12}, TFromD{13}, TFromD{14}, + TFromD{15}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}, TFromD{4}, TFromD{5}, + TFromD{6}, TFromD{7}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + const RebindToUnsigned du; + return BitCast(d, Dup128VecFromValues(du, uint16_t{0}, uint16_t{0x3C00}, + uint16_t{0x4000}, uint16_t{0x4200}, + uint16_t{0x4400}, uint16_t{0x4500}, + uint16_t{0x4600}, uint16_t{0x4700})); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}); } -#endif // HWY_HAVE_FLOAT64 #if HWY_COMPILER_MSVC template @@ -1226,9 +1324,6 @@ HWY_API VFromD Iota(D d, const T2 first) { #endif } -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ Combine // Full result @@ -1274,30 +1369,25 @@ HWY_API Vec128 Combine(D /* tag */, Vec64 hi, return Vec128(vcombine_s64(lo.raw, hi.raw)); } -template -HWY_API Vec128 Combine(D d, Vec64 hi, - Vec64 lo) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API Vec128 Combine(D, Vec64 hi, Vec64 lo) { return Vec128(vcombine_f16(lo.raw, hi.raw)); -#else - const RebindToUnsigned du; - const Half duh; - return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo))); -#endif } +#endif // HWY_HAVE_FLOAT16 -template -HWY_API Vec128 Combine(D d, Vec64 hi, - Vec64 lo) { #if HWY_NEON_HAVE_BFLOAT16 - (void)d; - return 
Vec128(vcombine_bf16(lo.raw, hi.raw)); -#else +template +HWY_API VFromD Combine(D, Vec64 hi, Vec64 lo) { + return VFromD(vcombine_bf16(lo.raw, hi.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API VFromD Combine(D d, VFromD hi, VFromD lo) { const RebindToUnsigned du; const Half duh; return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo))); -#endif } template @@ -1341,7 +1431,7 @@ HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) #if !HWY_HAVE_FLOAT16 -#if HWY_NEON_HAVE_FLOAT16C +#if HWY_NEON_HAVE_F16C HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) #else @@ -1349,7 +1439,7 @@ template HWY_INLINE Vec128 BitCastToByte(Vec128 v) { return BitCastToByte(Vec128(v.raw)); } -#endif // HWY_NEON_HAVE_FLOAT16C +#endif // HWY_NEON_HAVE_F16C #endif // !HWY_HAVE_FLOAT16 #if !HWY_NEON_HAVE_BFLOAT16 @@ -1406,14 +1496,24 @@ HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { return Vec64(vreinterpret_s64_u8(v.raw)); } +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. template -HWY_INLINE VFromD BitCastFromByte(D d, VFromD> v) { -#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_FLOAT16C - (void)d; +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C return VFromD(vreinterpret_f16_u8(v.raw)); #else const RebindToUnsigned du; - return VFromD(BitCastFromByte(du, v).raw); + return VFromD(BitCastFromByte(du, v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpret_bf16_u8(v.raw)); +#else + const RebindToUnsigned du; + return VFromD(BitCastFromByte(du, v).raw); #endif } @@ -1461,15 +1561,6 @@ HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { return Vec128(vreinterpretq_s64_u8(v.raw)); } -template -HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { -#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_FLOAT16C - return Vec128(vreinterpretq_f16_u8(v.raw)); -#else - return Vec128(BitCastFromByte(RebindToUnsigned(), v).raw); -#endif -} - template HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { return Vec128(vreinterpretq_f32_u8(v.raw)); @@ -1482,11 +1573,23 @@ HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { } #endif // HWY_HAVE_FLOAT64 -// Special case for bfloat16_t, which may have the same Raw as uint16_t. +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. 
+template +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C + return VFromD(vreinterpretq_f16_u8(v.raw)); +#else + return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif +} + template -HWY_INLINE VFromD BitCastFromByte(D /* tag */, - VFromD> v) { +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpretq_bf16_u8(v.raw)); +#else return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif } } // namespace detail @@ -1542,6 +1645,14 @@ namespace detail { #define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(GetLane, vget, _lane_, HWY_GET) + +template )> +static HWY_INLINE HWY_MAYBE_UNUSED TFromV GetLane(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCastScalar>(GetLane(BitCast(du, v))); +} #undef HWY_NEON_BUILD_TPL_HWY_GET #undef HWY_NEON_BUILD_RET_HWY_GET @@ -1688,12 +1799,21 @@ namespace detail { #define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(InsertLane, vset, _lane_, HWY_INSERT) #undef HWY_NEON_BUILD_TPL_HWY_INSERT #undef HWY_NEON_BUILD_RET_HWY_INSERT #undef HWY_NEON_BUILD_PARAM_HWY_INSERT #undef HWY_NEON_BUILD_ARG_HWY_INSERT +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API V InsertLane(const V v, TFromD t) { + const D d; + const RebindToUnsigned du; + const uint16_t tu = BitCastScalar(t); + return BitCast(d, InsertLane(BitCast(du, v), tu)); +} + } // namespace detail // Requires one overload per vector length because InsertLane<3> may be a @@ -1842,6 +1962,89 @@ HWY_API Vec128 SumsOf8(const Vec128 v) { HWY_API Vec64 SumsOf8(const Vec64 v) { return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); } +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_s32(vpaddlq_s16(vpaddlq_s8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_s32(vpaddl_s16(vpaddl_s8(v.raw)))); +} + +// ------------------------------ SumsOf2 +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return 
VFromD>>(vpaddlq_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u32(v.raw)); +} + +} // namespace detail // ------------------------------ SaturatedAdd @@ -1880,8 +2083,14 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedSub, vqsub, _, 2) // ------------------------------ Average // Returns (a + b + 1) / 2 -HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) -HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +HWY_NEON_DEF_FUNCTION_UI_8_16_32(AverageRound, vrhadd, _, 2) // ------------------------------ Neg @@ -1922,8 +2131,39 @@ HWY_API Vec128 Neg(const Vec128 v) { #endif } +// ------------------------------ SaturatedNeg +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedNeg, vqneg, _, 1) + +#if HWY_ARCH_ARM_A64 +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_API Vec64 SaturatedNeg(const Vec64 v) { + return Vec64(vqneg_s64(v.raw)); +} + +HWY_API Vec128 SaturatedNeg(const Vec128 v) { + return Vec128(vqnegq_s64(v.raw)); +} +#endif + // ------------------------------ ShiftLeft +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + // Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). 
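// [Illustrative aside, not part of the upstream Highway patch.] The
// HWY_NATIVE_ROUNDING_SHR block above advertises the rounding shifts defined
// in this section (vrshr/vrshl wrappers). A scalar model of the unsigned
// case, using a hypothetical helper name and assuming the usual NEON
// semantics of adding 1 << (count - 1) before shifting:
//
//   constexpr uint32_t RoundShrModel(uint32_t x, int count) {
//     return (x + (1u << (count - 1))) >> count;
//   }
//   static_assert(RoundShrModel(7, 2) == 2, "7/4 rounds to 2");  // 7 >> 2 == 1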
#pragma push_macro("HWY_NEON_DEF_FUNCTION") #undef HWY_NEON_DEF_FUNCTION @@ -1939,16 +2179,22 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_UINTS(RoundingShiftRight, vrshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(RoundingShiftRight, vrshr, _n_, ignored) #pragma pop_macro("HWY_NEON_DEF_FUNCTION") // ------------------------------ RotateRight (ShiftRight, Or) -template +template HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; - return Or(ShiftRight(v), + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } @@ -2111,6 +2357,95 @@ HWY_API Vec64 operator>>(Vec64 v, Vec64 bits) { return Vec64(vshl_s64(v.raw, Neg(bits).raw)); } +// ------------------------------ RoundingShr (Neg) + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int8x16_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int8x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int64x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + const RebindToSigned> di; + const int64x1_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec64(vrshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + return Vec64(vrshl_s64(v.raw, Neg(bits).raw)); +} + // ------------------------------ ShiftLeftSame (Shl) template @@ -2122,6 
+2457,13 @@ HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { return v >> Set(DFromV(), static_cast(bits)); } +// ------------------------------ RoundingShiftRightSame (RoundingShr) + +template +HWY_API Vec128 RoundingShiftRightSame(const Vec128 v, int bits) { + return RoundingShr(v, Set(DFromV(), static_cast(bits))); +} + // ------------------------------ Int/float multiplication // Per-target flag to prevent generic_ops-inl.h from defining 8-bit operator*. @@ -2138,7 +2480,39 @@ HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) // ------------------------------ Integer multiplication -// Returns the upper 16 bits of a * b in each lane. +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int16x8_t rlo = vmull_s8(vget_low_s8(a.raw), vget_low_s8(b.raw)); +#if HWY_ARCH_ARM_A64 + int16x8_t rhi = vmull_high_s8(a.raw, b.raw); +#else + int16x8_t rhi = vmull_s8(vget_high_s8(a.raw), vget_high_s8(b.raw)); +#endif + return Vec128( + vuzp2q_s8(vreinterpretq_s8_s16(rlo), vreinterpretq_s8_s16(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint16x8_t rlo = vmull_u8(vget_low_u8(a.raw), vget_low_u8(b.raw)); +#if HWY_ARCH_ARM_A64 + uint16x8_t rhi = vmull_high_u8(a.raw, b.raw); +#else + uint16x8_t rhi = vmull_u8(vget_high_u8(a.raw), vget_high_u8(b.raw)); +#endif + return Vec128( + vuzp2q_u8(vreinterpretq_u8_u16(rlo), vreinterpretq_u8_u16(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int8x16_t hi_lo = vreinterpretq_s8_s16(vmull_s8(a.raw, b.raw)); + return Vec128(vget_low_s8(vuzp2q_s8(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint8x16_t hi_lo = vreinterpretq_u8_u16(vmull_u8(a.raw, b.raw)); + return Vec128(vget_low_u8(vuzp2q_u8(hi_lo, hi_lo))); +} + HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); #if HWY_ARCH_ARM_A64 @@ -2172,6 +2546,57 @@ HWY_API Vec128 MulHigh(Vec128 a, return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); } +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int64x2_t rlo = vmull_s32(vget_low_s32(a.raw), vget_low_s32(b.raw)); +#if HWY_ARCH_ARM_A64 + int64x2_t rhi = vmull_high_s32(a.raw, b.raw); +#else + int64x2_t rhi = vmull_s32(vget_high_s32(a.raw), vget_high_s32(b.raw)); +#endif + return Vec128( + vuzp2q_s32(vreinterpretq_s32_s64(rlo), vreinterpretq_s32_s64(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint64x2_t rlo = vmull_u32(vget_low_u32(a.raw), vget_low_u32(b.raw)); +#if HWY_ARCH_ARM_A64 + uint64x2_t rhi = vmull_high_u32(a.raw, b.raw); +#else + uint64x2_t rhi = vmull_u32(vget_high_u32(a.raw), vget_high_u32(b.raw)); +#endif + return Vec128( + vuzp2q_u32(vreinterpretq_u32_u64(rlo), vreinterpretq_u32_u64(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int32x4_t hi_lo = vreinterpretq_s32_s64(vmull_s32(a.raw, b.raw)); + return Vec128(vget_low_s32(vuzp2q_s32(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, + Vec128 b) { + uint32x4_t hi_lo = vreinterpretq_u32_u64(vmull_u32(a.raw, b.raw)); + return Vec128(vget_low_u32(vuzp2q_u32(hi_lo, hi_lo))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + HWY_API Vec128 
MulFixedPoint15(Vec128 a, Vec128 b) { return Vec128(vqrdmulhq_s16(a.raw, b.raw)); } @@ -2184,7 +2609,7 @@ HWY_API Vec128 MulFixedPoint15(Vec128 a, // ------------------------------ Floating-point division // Emulate missing intrinsic -#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +#if HWY_HAVE_FLOAT64 && HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 HWY_INLINE float64x1_t vrecpe_f64(float64x1_t raw) { const CappedTag d; const Twice dt; @@ -2277,7 +2702,7 @@ HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, namespace detail { -#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#if HWY_NATIVE_FMA // Wrappers for changing argument order to what intrinsics expect. HWY_NEON_DEF_FUNCTION_ALL_FLOATS(MulAdd, vfma, _, 3) HWY_NEON_DEF_FUNCTION_ALL_FLOATS(NegMulAdd, vfms, _, 3) @@ -2295,7 +2720,7 @@ HWY_API Vec128 NegMulAdd(Vec128 add, Vec128 mul, return add - mul * x; } -#endif // defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#endif // HWY_NATIVE_FMA } // namespace detail template @@ -2310,13 +2735,13 @@ HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, return detail::NegMulAdd(add, mul, x); } -template +template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { return MulAdd(mul, x, Neg(sub)); } -template +template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { return Neg(MulAdd(mul, x, sub)); @@ -2612,6 +3037,15 @@ HWY_API Vec128 PopulationCount(Vec128 v) { HWY_NEON_DEF_FUNCTION_INT_8_16_32(Abs, vabs, _, 1) HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Abs, vabs, _, 1) +// ------------------------------ SaturatedAbs +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedAbs, vqabs, _, 1) + // ------------------------------ CopySign template HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { @@ -2675,22 +3109,49 @@ HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_BF16(TFromV) +#else +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_SPECIAL_FLOAT_V(V) +#endif + +template +HWY_API V IfThenElse(MFromD> mask, V yes, V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#undef HWY_NEON_IF_EMULATED_IF_THEN_ELSE #undef HWY_NEON_BUILD_TPL_HWY_IF #undef HWY_NEON_BUILD_RET_HWY_IF #undef HWY_NEON_BUILD_PARAM_HWY_IF #undef HWY_NEON_BUILD_ARG_HWY_IF // mask ? yes : 0 -template +template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { return yes & VecFromMask(DFromV(), mask); } +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} // mask ? 
0 : no -template +template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { return AndNot(VecFromMask(DFromV(), mask), no); } +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} template HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, @@ -2703,12 +3164,6 @@ HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, return IfThenElse(m, yes, no); } -template -HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const auto zero = Zero(DFromV()); - return Max(zero, v); -} - // ------------------------------ Mask logical template @@ -2957,6 +3412,23 @@ HWY_API Vec64 Abs(const Vec64 v) { #endif } +HWY_API Vec128 SaturatedAbs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vqabsq_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v); +#endif +} +HWY_API Vec64 SaturatedAbs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vqabs_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v); +#endif +} + // ------------------------------ Min (IfThenElse, BroadcastSignBit) // Unsigned @@ -3133,6 +3605,20 @@ HWY_API Vec128 LoadU(D /* tag */, const int64_t* HWY_RESTRICT unaligned) { return Vec128(vld1q_s64(unaligned)); } +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const float16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const bfloat16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_bf16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT unaligned) { return Vec128(vld1q_f32(unaligned)); @@ -3179,6 +3665,18 @@ template HWY_API Vec64 LoadU(D /* tag */, const int64_t* HWY_RESTRICT p) { return Vec64(vld1_s64(p)); } +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec64(vld1_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return Vec64(vld1_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API Vec64 LoadU(D /* tag */, const float* HWY_RESTRICT p) { return Vec64(vld1_f32(p)); @@ -3207,14 +3705,34 @@ HWY_API Vec32 LoadU(D /*tag*/, const float* HWY_RESTRICT p) { return Vec32(vld1_dup_f32(p)); } -template +// {u,i}{8,16} +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { const Repartition d32; uint32_t buf; CopyBytes<4>(p, &buf); return BitCast(d, LoadU(d32, &buf)); } +#endif // HWY_NEON_HAVE_BFLOAT16 // ------------------------------ Load 16 @@ -3228,6 +3746,18 @@ template HWY_API VFromD LoadU(D /* tag */, const int16_t* HWY_RESTRICT p) { return 
VFromD(vld1_dup_s16(p)); } +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 // 8-bit x2 template @@ -3250,12 +3780,10 @@ HWY_API VFromD LoadU(D /* tag */, const int8_t* HWY_RESTRICT p) { // ------------------------------ Load misc -// [b]float16_t may use the same Raw as uint16_t, so forward to that. -template +template HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { - const RebindToUnsigned du16; - const auto pu16 = reinterpret_cast(p); - return BitCast(d, LoadU(du16, pu16)); + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); } // On Arm, Load is the same as LoadU. @@ -3324,6 +3852,20 @@ HWY_API void StoreU(Vec128 v, D /* tag */, int64_t* HWY_RESTRICT unaligned) { vst1q_s64(unaligned, v.raw); } +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + float16_t* HWY_RESTRICT unaligned) { + vst1q_f16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + bfloat16_t* HWY_RESTRICT unaligned) { + vst1q_bf16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API void StoreU(Vec128 v, D /* tag */, float* HWY_RESTRICT unaligned) { @@ -3371,6 +3913,20 @@ template HWY_API void StoreU(Vec64 v, D /* tag */, int64_t* HWY_RESTRICT p) { vst1_s64(p, v.raw); } +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + vst1_f16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + bfloat16_t* HWY_RESTRICT p) { + vst1_bf16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API void StoreU(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { vst1_f32(p, v.raw); @@ -3397,28 +3953,31 @@ HWY_API void StoreU(Vec32 v, D, float* HWY_RESTRICT p) { vst1_lane_f32(p, v.raw, 0); } -// Overload 16-bit types directly to avoid ambiguity with [b]float16_t. 
-template , - HWY_IF_T_SIZE(T, 1)> -HWY_API void StoreU(Vec32 v, D d, T* HWY_RESTRICT p) { +// {u,i}{8,16} +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { Repartition d32; uint32_t buf = GetLane(BitCast(d32, v)); CopyBytes<4>(&buf, p); } -template -HWY_API void StoreU(Vec32 v, D d, uint16_t* HWY_RESTRICT p) { +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { Repartition d32; uint32_t buf = GetLane(BitCast(d32, v)); CopyBytes<4>(&buf, p); } - -template -HWY_API void StoreU(Vec32 v, D d, int16_t* HWY_RESTRICT p) { +#endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { Repartition d32; uint32_t buf = GetLane(BitCast(d32, v)); CopyBytes<4>(&buf, p); } +#endif // HWY_NEON_HAVE_BFLOAT16 // ------------------------------ Store 16 @@ -3430,6 +3989,18 @@ template HWY_API void StoreU(Vec16 v, D, int16_t* HWY_RESTRICT p) { vst1_lane_s16(p, v.raw, 0); } +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec16 v, D, float16_t* HWY_RESTRICT p) { + vst1_lane_f16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec16 v, D, bfloat16_t* HWY_RESTRICT p) { + vst1_lane_bf16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { @@ -3449,12 +4020,12 @@ HWY_API void StoreU(Vec128 v, D, int8_t* HWY_RESTRICT p) { vst1_lane_s8(p, v.raw, 0); } -// [b]float16_t may use the same Raw as uint16_t, so forward to that. -template +// ------------------------------ Store misc + +template HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { - const RebindToUnsigned du16; - const auto pu16 = reinterpret_cast(p); - return StoreU(BitCast(du16, v), du16, pu16); + const RebindToUnsigned du; + return StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); } HWY_DIAGNOSTICS(push) @@ -3541,24 +4112,6 @@ HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { return VFromD(vcvt_f32_u32(v.raw)); } -// Truncates (rounds toward zero). -template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { - return Vec128(vcvtq_s32_f32(v.raw)); -} -template -HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { - return VFromD(vcvt_s32_f32(v.raw)); -} -template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { - return Vec128(vcvtq_u32_f32(ZeroIfNegative(v).raw)); -} -template -HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { - return VFromD(vcvt_u32_f32(ZeroIfNegative(v).raw)); -} - #if HWY_HAVE_FLOAT64 template @@ -3577,51 +4130,168 @@ HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { template HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { - return Vec128(vcvtq_f64_u64(ZeroIfNegative(v).raw)); + return Vec128(vcvtq_f64_u64(v.raw)); } template HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { // GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. - const auto non_neg_v = ZeroIfNegative(v); #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 - return Set(Full64(), static_cast(GetLane(non_neg_v))); + return Set(Full64(), static_cast(GetLane(v))); #else - return Vec64(vcvt_f64_u64(non_neg_v.raw)); + return Vec64(vcvt_f64_u64(v.raw)); #endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 } +#endif // HWY_HAVE_FLOAT64 + +namespace detail { // Truncates (rounds toward zero). 
-template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. + + int32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.4s, %1.4s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. + + int32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.2s, %1.2s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.4s, %1.4s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.2s, %1.2s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_u32_f32(v.raw)); +#endif +} + +#if HWY_HAVE_FLOAT64 + +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. + int64x2_t raw_result; + __asm__("fcvtzs %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else return Vec128(vcvtq_s64_f64(v.raw)); +#endif } -template -HWY_API Vec64 ConvertTo(D di, Vec64 v) { - // GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. Use the - // 128-bit version to avoid UB from casting double -> int64_t. 
-#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 - const Full128 ddt; - const Twice dit; - return LowerHalf(di, ConvertTo(dit, Combine(ddt, v, v))); +template +HWY_INLINE Vec64 ConvertFToI(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. + // If compiling for AArch64 NEON with GCC 6 or earlier, use inline assembly to + // work around the missing vcvt_s64_f64 intrinsic. + int64x1_t raw_result; + __asm__("fcvtzs %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); #else - (void)di; return Vec64(vcvt_s64_f64(v.raw)); #endif } -template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. + uint64x2_t raw_result; + __asm__("fcvtzu %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else return Vec128(vcvtq_u64_f64(v.raw)); +#endif } -template -HWY_API Vec64 ConvertTo(D du, Vec64 v) { - // GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. Use the - // 128-bit version to avoid UB from casting double -> uint64_t. -#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 - const Full128 ddt; - const Twice du_t; - return LowerHalf(du, ConvertTo(du_t, Combine(ddt, v, v))); +template +HWY_INLINE Vec64 ConvertFToU(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. + + // Inline assembly is also used if compiling for AArch64 NEON with GCC 6 or + // earlier to work around the issue of the missing vcvt_u64_f64 intrinsic. + uint64x1_t raw_result; + __asm__("fcvtzu %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); #else - (void)du; return Vec64(vcvt_u64_f64(v.raw)); #endif } @@ -3631,25 +4301,76 @@ HWY_API Vec64 ConvertTo(D du, Vec64 v) { #if HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 // Truncates (rounds toward zero). -template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. + int16x8_t raw_result; + __asm__("fcvtzs %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else return Vec128(vcvtq_s16_f16(v.raw)); +#endif } template -HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. 
+ int16x4_t raw_result; + __asm__("fcvtzs %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else return VFromD(vcvt_s16_f16(v.raw)); +#endif } -template -HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. + uint16x8_t raw_result; + __asm__("fcvtzu %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else return Vec128(vcvtq_u16_f16(v.raw)); +#endif } template -HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. + uint16x4_t raw_result; + __asm__("fcvtzu %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else return VFromD(vcvt_u16_f16(v.raw)); +#endif } #endif // HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 +} // namespace detail + +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + return detail::ConvertFToI(di, v); +} + +template +HWY_API VFromD ConvertTo(D du, VFromD> v) { + return detail::ConvertFToU(du, v); +} // ------------------------------ PromoteTo (ConvertTo) @@ -3782,7 +4503,7 @@ HWY_API VFromD PromoteTo(D d, V v) { return PromoteTo(d, PromoteTo(di32, v)); } -#if HWY_NEON_HAVE_FLOAT16C +#if HWY_NEON_HAVE_F16C // Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. #ifdef HWY_NATIVE_F16C @@ -3800,7 +4521,7 @@ HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { return VFromD(vget_low_f32(vcvt_f32_f16(v.raw))); } -#endif // HWY_NEON_HAVE_FLOAT16C +#endif // HWY_NEON_HAVE_F16C #if HWY_HAVE_FLOAT64 @@ -3893,8 +4614,36 @@ HWY_API VFromD PromoteTo(D du64, VFromD> v) { lo32_or_mask); } +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + constexpr uint32_t kExpAdjDecr = + 0xFFFFFF9Du + static_cast(!IsSigned>()); + + const auto exponent_adj = BitCast( + du32, SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, kExpAdjDecr)))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + return PromoteTo(d64, ConvertTo(d32, adj_v)) << PromoteTo(d64, exponent_adj); +} + #endif // HWY_HAVE_FLOAT64 +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + // ------------------------------ PromoteUpperTo #if HWY_ARCH_ARM_A64 @@ -3946,14 +4695,14 @@ HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { return Vec128(vmovl_high_s32(v.raw)); } -#if HWY_NEON_HAVE_FLOAT16C +#if HWY_NEON_HAVE_F16C template HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { return Vec128(vcvt_high_f32_f16(v.raw)); } -#endif // HWY_NEON_HAVE_FLOAT16C +#endif // HWY_NEON_HAVE_F16C template HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { @@ -4149,7 +4898,7 @@ HWY_API VFromD DemoteTo(D d, Vec64 v) { return DemoteTo(d, DemoteTo(du32, v)); } -#if 
HWY_NEON_HAVE_FLOAT16C +#if HWY_NEON_HAVE_F16C // We already toggled HWY_NATIVE_F16C above. @@ -4162,16 +4911,47 @@ HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { return VFromD(vcvt_f16_f32(vcombine_f32(v.raw, v.raw))); } -#endif // HWY_NEON_HAVE_FLOAT16C +#endif // HWY_NEON_HAVE_F16C -template -HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { - const Rebind di32; - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); - return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +#if HWY_NEON_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { +#if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. + +// The bfloat16x4_t vector returned by vcvt_bf16_f32 needs to be bitcasted to +// an uint16x4_t vector if HWY_NEON_HAVE_F32_TO_BF16C && +// !HWY_NEON_HAVE_BFLOAT16 is true. +static HWY_INLINE uint16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return vreinterpret_u16_bf16(raw); } +#endif +} // namespace detail + +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16(vcvt_bf16_f32(v.raw))); +} +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16( + vcvt_bf16_f32(vcombine_f32(v.raw, v.raw)))); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C #if HWY_HAVE_FLOAT64 @@ -4184,32 +4964,10 @@ HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); } -template -HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { - const int64x2_t i64 = vcvtq_s64_f64(v.raw); - return Vec64(vqmovn_s64(i64)); -} -template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { - // There is no i64x1 -> i32x1 narrow, so Combine to 128-bit. Do so with the - // f64 input already to also avoid the missing vcvt_s64_f64 in GCC 6.4. - const Full128 ddt; - const Full128 dit; - return Vec32(vqmovn_s64(ConvertTo(dit, Combine(ddt, v, v)).raw)); -} - -template -HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { - const uint64x2_t u64 = vcvtq_u64_f64(v.raw); - return Vec64(vqmovn_u64(u64)); -} -template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { - // There is no i64x1 -> i32x1 narrow, so Combine to 128-bit. Do so with the - // f64 input already to also avoid the missing vcvt_s64_f64 in GCC 6.4. 
- const Full128 ddt; - const Full128 du_t; - return Vec32(vqmovn_u64(ConvertTo(du_t, Combine(ddt, v, v)).raw)); +template +HWY_API VFromD DemoteTo(D d32, VFromD> v) { + const Rebind>, D> d64; + return DemoteTo(d32, ConvertTo(d64, v)); } #endif // HWY_HAVE_FLOAT64 @@ -4438,8 +5196,101 @@ HWY_API Vec128 Floor(const Vec128 v) { #endif +// ------------------------------ CeilInt/FloorInt +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s16_f16(v.raw)); +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s16_f16(v.raw)); +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s32_f32(v.raw)); +} + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtp_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, CeilInt(Combine(dt, v, v))); +#else + return Vec128(vcvtp_s64_f64(v.raw)); +#endif +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s32_f32(v.raw)); +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtm_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, FloorInt(Combine(dt, v, v))); +#else + return Vec128(vcvtm_s64_f64(v.raw)); +#endif +} + +#endif // HWY_ARCH_ARM_A64 + // ------------------------------ NearestInt (Round) +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s16_f16(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s16_f16(v.raw)); +} +#endif + #if HWY_ARCH_ARM_A64 HWY_API Vec128 NearestInt(const Vec128 v) { @@ -4450,6 +5301,29 @@ HWY_API Vec128 NearestInt(const Vec128 v) { return Vec128(vcvtn_s32_f32(v.raw)); } +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtn_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, NearestInt(Combine(dt, v, v))); +#else + return Vec128(vcvtn_s64_f64(v.raw)); +#endif +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, NearestInt(v)); +} + #else template @@ -4461,34 +5335,62 @@ HWY_API Vec128 NearestInt(const Vec128 v) { #endif // ------------------------------ Floating-point classification + +#if !HWY_COMPILER_CLANG || HWY_COMPILER_CLANG > 1801 || HWY_ARCH_ARM_V7 template HWY_API Mask128 IsNaN(const Vec128 v) { return v != v; } +#else +// Clang up to 18.1 generates less efficient code than 
the expected FCMEQ, see +// https://github.com/numpy/numpy/issues/27313 and +// https://github.com/numpy/numpy/pull/22954/files and +// https://github.com/llvm/llvm-project/issues/59855 -template -HWY_API Mask128 IsInf(const Vec128 v) { - const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); - // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.8h, %1.8h, %1.8h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4h, %1.4h, %1.4h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); } +#endif // HWY_HAVE_FLOAT16 -// Returns whether normal/subnormal/zero. -template -HWY_API Mask128 IsFinite(const Vec128 v) { - const DFromV d; - const RebindToUnsigned du; - const RebindToSigned di; // cheaper than unsigned comparison - const VFromD vu = BitCast(du, v); - // 'Shift left' to clear the sign bit, then right so we can compare with the - // max exponent (cannot compare with MaxExponentTimes2 directly because it is - // negative and non-negative floats would be greater). - const VFromD exp = - BitCast(di, ShiftRight() + 1>(Add(vu, vu))); - return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4s, %1.4s, %1.4s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2s, %1.2s, %1.2s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2d, %1.2d, %1.2d" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %d0, %d1, %d1" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); } +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_COMPILER_CLANG // ================================================== SWIZZLE @@ -4532,13 +5434,18 @@ HWY_API Vec64 LowerHalf(Vec128 v) { return Vec64(vget_low_f16(v.raw)); } #endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 #if HWY_HAVE_FLOAT64 HWY_API Vec64 LowerHalf(Vec128 v) { return Vec64(vget_low_f64(v.raw)); } #endif // HWY_HAVE_FLOAT64 -template +template ), HWY_IF_V_SIZE_V(V, 16)> HWY_API VFromD>> LowerHalf(V v) { const Full128 du; const Half> dh; @@ -4738,6 +5645,12 @@ HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { return Vec64(vget_high_f16(v.raw)); } #endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { return Vec64(vget_high_f32(v.raw)); @@ -4749,7 +5662,7 @@ HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { } #endif // HWY_HAVE_FLOAT64 -template +template HWY_API VFromD UpperHalf(D dh, VFromD> v) { const RebindToUnsigned> du; const Half duh; @@ -4869,6 +5782,20 @@ HWY_API Vec128 Broadcast(Vec128 v) { } #endif // HWY_HAVE_FLOAT16 +#if 
HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_bf16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + template HWY_API Vec128 Broadcast(Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); @@ -4976,7 +5903,26 @@ HWY_API Vec128 Broadcast(Vec128 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); return Vec128(vdupq_n_f16(vgetq_lane_f16(v.raw, kLane))); } +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f16(v.raw, kLane)); +} #endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_bf16(vgetq_lane_bf16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 template HWY_API Vec128 Broadcast(Vec128 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); @@ -4991,6 +5937,14 @@ HWY_API Vec128 Broadcast(Vec128 v) { #endif // HWY_ARCH_ARM_A64 +template ), + HWY_IF_LANES_GT_D(DFromV, 1)> +HWY_API V Broadcast(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Broadcast(BitCast(du, v))); +} + // ------------------------------ TableLookupLanes // Returned by SetTableIndices for use by TableLookupLanes. @@ -5393,6 +6347,16 @@ HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { } #endif +#if !HWY_HAVE_FLOAT16 +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} +#endif // !HWY_HAVE_FLOAT16 + // < 64 bit parts template HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { @@ -5717,117 +6681,615 @@ HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { (void)d; #endif - return detail::SlideDownLanes(v, amt); + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return 
ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ------------------------------- WidenHighMulAdd + +#ifdef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#undef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#endif + +namespace detail { + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_u32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_s32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec64 hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec32 hi = UpperHalf(d, Vec64(vget_high_s32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, 
Vec64(vget_high_u32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template), class DN = RepartitionToNarrow, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + const Twice d16F; + VFromD hi = UpperHalf(d, VFromD(vget_high_u16(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + const Twice d16F; + VFromD hi = UpperHalf(d, VFromD(vget_high_s16(widen.raw))); + return hi + add; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteUpperTo(d, mul), PromoteUpperTo(d, x)); +} +#endif +#endif + +} // namespace detail + +// ------------------------------- WidenMulAdd + +#ifdef HWY_NATIVE_WIDEN_MUL_ADD +#undef HWY_NATIVE_WIDEN_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_MUL_ADD +#endif + +namespace detail { + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_s8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_s16(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s16(mul.raw, x.raw)); + const VFromD mul10 = LowerHalf(mulRs); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(D, 1)> +HWY_API VFromD 
WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = LowerHalf(Vec128(vmull_s16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u16(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u16(mul.raw, x.raw)); + const Vec64 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = + LowerHalf(Vec128(vmull_u16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_s32(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_u32(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 1)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template, + HWY_IF_LANES_D(D, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_low_f16(add.raw, mul.raw, x.raw)); +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_low_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenLowMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteLowerTo(d, mul), PromoteLowerTo(d, x)); } +#endif +#endif -template -HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { -#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang - if (__builtin_constant_p(amt)) { - switch (amt) { - case 0: - return v; - case 1: - return ShiftRightLanes<1>(d, v); - case 2: - return ShiftRightLanes<2>(d, v); - case 3: - return ShiftRightLanes<3>(d, v); - } - } +} // namespace detail + +// ------------------------------ WidenMulAccumulate + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE #else - (void)d; +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE #endif - return detail::SlideDownLanes(v, amt); +template), class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenMulAdd(d, LowerHalf(mul), LowerHalf(x), low); } -template -HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { -#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang - if (__builtin_constant_p(amt)) { - switch (amt) { - case 0: - return v; - case 1: - return ShiftRightLanes<1>(d, v); - case 2: - return ShiftRightLanes<2>(d, v); - case 3: - return ShiftRightLanes<3>(d, v); - case 4: - return ShiftRightLanes<4>(d, v); - case 5: - return ShiftRightLanes<5>(d, v); - case 6: - return ShiftRightLanes<6>(d, v); - case 7: - return ShiftRightLanes<7>(d, v); - } - 
} +#if 0 +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 #else - (void)d; +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 #endif - return detail::SlideDownLanes(v, amt); +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenLowMulAdd(d, mul, x, low); } -template -HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { -#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang - if (__builtin_constant_p(amt)) { - switch (amt) { - case 0: - return v; - case 1: - return ShiftRightLanes<1>(d, v); - case 2: - return ShiftRightLanes<2>(d, v); - case 3: - return ShiftRightLanes<3>(d, v); - case 4: - return ShiftRightLanes<4>(d, v); - case 5: - return ShiftRightLanes<5>(d, v); - case 6: - return ShiftRightLanes<6>(d, v); - case 7: - return ShiftRightLanes<7>(d, v); - case 8: - return ShiftRightLanes<8>(d, v); - case 9: - return ShiftRightLanes<9>(d, v); - case 10: - return ShiftRightLanes<10>(d, v); - case 11: - return ShiftRightLanes<11>(d, v); - case 12: - return ShiftRightLanes<12>(d, v); - case 13: - return ShiftRightLanes<13>(d, v); - case 14: - return ShiftRightLanes<14>(d, v); - case 15: - return ShiftRightLanes<15>(d, v); - } - } +#endif +#endif + +// ------------------------------ SatWidenMulAccumFixedPoint + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT #else - (void)d; +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT #endif - return detail::SlideDownLanes(v, amt); +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vqdmlal_s16(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Full128> di32_full; + const Rebind di16_full64; + return ResizeBitCast( + di32, SatWidenMulAccumFixedPoint(di32_full, ResizeBitCast(di16_full64, a), + ResizeBitCast(di16_full64, b), + ResizeBitCast(di32_full, sum))); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) +#if HWY_NEON_HAVE_F32_TO_BF16C + +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +namespace detail { #if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(bfloat16x8_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. 
+ +// The uint16x4_t or uint16x8_t vector neets to be bitcasted to a bfloat16x4_t +// or a bfloat16x8_t vector for the vbfdot_f32 and vbfdotq_f32 intrinsics if +// HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true +static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(uint16x4_t raw) { + return vreinterpret_bf16_u16(raw); +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(uint16x8_t raw) { + return vreinterpretq_bf16_u16(raw); +} +#endif +} // namespace detail + +template +HWY_API Vec128 MulEvenAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlalbq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 MulOddAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlaltq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} template HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, Vec128 b, const Vec128 sum0, Vec128& /*sum1*/) { - return Vec128(vbfdotq_f32(sum0.raw, a.raw, b.raw)); + return Vec128(vbfdotq_f32(sum0.raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +// There is no non-q version of these instructions. +template +HWY_API VFromD MulEvenAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulEvenAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); +} + +template +HWY_API VFromD MulOddAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulOddAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); } template @@ -5835,28 +7297,11 @@ HWY_API VFromD ReorderWidenMulAccumulate( D /*d32*/, VFromD> a, VFromD> b, const VFromD sum0, VFromD& /*sum1*/) { - return VFromD(vbfdot_f32(sum0.raw, a.raw, b.raw)); -} - -#else - -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, - const VFromD sum0, - VFromD& sum1) { - const RebindToUnsigned du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); + return VFromD(vbfdot_f32(sum0.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); } -#endif // HWY_NEON_HAVE_BFLOAT16 +#endif // HWY_NEON_HAVE_F32_TO_BF16C template HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, @@ -6024,39 +7469,108 @@ HWY_API Vec32 RearrangeToOddPlusEven(Vec32 sum0, return sum0 + sum1; } +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET == HWY_NEON_BF16 + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdot_s32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdotq_s32(sum.raw, a.raw, b.raw)); +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define 
HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdot_u32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdotq_u32(sum.raw, a.raw, b.raw)); +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 di32, VFromD> a_u, + VFromD> b_i, VFromD sum) { + // TODO: use vusdot[q]_s32 on NEON targets that require support for NEON I8MM + + const RebindToUnsigned du32; + const Repartition du8; + + const auto b_u = BitCast(du8, b_i); + const auto result_sum0 = + SumOfMulQuadAccumulate(du32, a_u, b_u, BitCast(du32, sum)); + const auto result_sum1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(du32, a_u, ShiftRight<7>(b_u), Zero(du32))); + + return BitCast(di32, Sub(result_sum0, result_sum1)); +} + +#endif // HWY_TARGET == HWY_NEON_BF16 + // ------------------------------ WidenMulPairwiseAdd -#if HWY_NEON_HAVE_BFLOAT16 +#if HWY_NEON_HAVE_F32_TO_BF16C -template -HWY_API Vec128 WidenMulPairwiseAdd(D d32, Vec128 a, +template +HWY_API Vec128 WidenMulPairwiseAdd(DF df, Vec128 a, Vec128 b) { - return Vec128(vbfdotq_f32(Zero(d32).raw, a.raw, b.raw)); + return Vec128(vbfdotq_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); } -template -HWY_API VFromD WidenMulPairwiseAdd(D d32, - VFromD> a, - VFromD> b) { - return VFromD(vbfdot_f32(Zero(d32).raw, a.raw, b.raw)); +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return VFromD(vbfdot_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); } #else -template -HWY_API VFromD WidenMulPairwiseAdd( - D32 df32, VFromD> a, - VFromD> b) { - const RebindToUnsigned du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), - Mul(BitCast(df32, ao), BitCast(df32, bo))); +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); } -#endif // HWY_NEON_HAVE_BFLOAT16 +#endif // HWY_NEON_HAVE_F32_TO_BF16C template HWY_API Vec128 WidenMulPairwiseAdd(D /*d32*/, Vec128 a, @@ -6266,6 +7780,23 @@ namespace detail { // There is no vuzpq_u64. 
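// Usage sketch for the widening bf16 multiply ops above (illustrative only,
// not part of the vendored header; assumes "hwy/highway.h" is included, the
// code runs inside HWY_NAMESPACE, and num is a multiple of Lanes(dbf)):
#if 0
float Bf16Dot(const bfloat16_t* HWY_RESTRICT pa,
              const bfloat16_t* HWY_RESTRICT pb, size_t num) {
  const ScalableTag<float> df32;
  const Repartition<bfloat16_t, decltype(df32)> dbf;  // twice as many lanes
  auto sum0 = Zero(df32);
  auto sum1 = Zero(df32);
  for (size_t i = 0; i < num; i += Lanes(dbf)) {
    const auto a = LoadU(dbf, pa + i);
    const auto b = LoadU(dbf, pb + i);
    sum0 = ReorderWidenMulAccumulate(df32, a, b, sum0, sum1);
  }
  // RearrangeToOddPlusEven combines the two accumulators so the lane sum
  // equals the dot product, regardless of any lane reordering above.
  return ReduceSum(df32, RearrangeToOddPlusEven(sum0, sum1));
}
#endif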
HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatEven, vuzp1, _, 2) HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatOdd, vuzp2, _, 2) + +#if !HWY_HAVE_FLOAT16 +template +HWY_INLINE Vec128 ConcatEven(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatEven(BitCast(du, hi), BitCast(du, lo))); +} +template +HWY_INLINE Vec128 ConcatOdd(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatOdd(BitCast(du, hi), BitCast(du, lo))); +} +#endif // !HWY_HAVE_FLOAT16 } // namespace detail // Full/half vector @@ -6374,6 +7905,36 @@ HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { return IfThenElse(MaskFromVec(vec), b, a); } +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[0]); +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[1]); +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + // ------------------------------ OddEvenBlocks template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { @@ -6395,12 +7956,14 @@ HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { // ------------------------------ ReorderDemote2To (OddEven) -template >> -HWY_API VFromD ReorderDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; - return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const Half dh_bf16; + return Combine(dbf16, DemoteTo(dh_bf16, b), DemoteTo(dh_bf16, a)); } +#endif // HWY_NEON_HAVE_F32_TO_BF16C template HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, @@ -6616,16 +8179,19 @@ HWY_API VFromD OrderedDemote2To(D d, V a, V b) { return ReorderDemote2To(d, a, b); } -template >> -HWY_API VFromD OrderedDemote2To(D dbf16, V32 a, V32 b) { +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { return ReorderDemote2To(dbf16, a, b); } +#endif // HWY_NEON_HAVE_F32_TO_BF16C // ================================================== CRYPTO // (aarch64 or Arm7) and (__ARM_FEATURE_AES or HWY_HAVE_RUNTIME_DISPATCH). // Otherwise, rely on generic_ops-inl.h to emulate AESRound / CLMul*. 
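// Usage sketch for AESRound (illustrative only; assumes "hwy/highway.h" is
// included, code inside HWY_NAMESPACE, and 16-byte inputs). Targets with the
// AES extension use the native implementation in this section; otherwise the
// generic_ops-inl.h emulation mentioned above applies.
#if 0
Vec128<uint8_t> OneAesRound(const uint8_t* HWY_RESTRICT block,
                            const uint8_t* HWY_RESTRICT round_key) {
  const Full128<uint8_t> d8;
  const Vec128<uint8_t> state = LoadU(d8, block);
  const Vec128<uint8_t> key = LoadU(d8, round_key);
  // One full round: SubBytes, ShiftRows, MixColumns, AddRoundKey.
  return AESRound(state, key);
}
#endif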
-#if HWY_TARGET == HWY_NEON +#if HWY_TARGET != HWY_NEON_WITHOUT_AES #ifdef HWY_NATIVE_AES #undef HWY_NATIVE_AES @@ -6676,7 +8242,7 @@ HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); } -#endif // HWY_TARGET == HWY_NEON +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES // ================================================== MISC @@ -6851,10 +8417,11 @@ HWY_API Vec128 MulEven(Vec128 a, vget_low_u64(vmull_u32(a_packed, b_packed))); } -HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { - uint64_t hi; - uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 0), vgetq_lane_u64(b.raw, 0), &hi); - return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(GetLane(a), GetLane(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); } // Multiplies odd lanes (1, 3 ..) and places the double-wide result into @@ -6957,10 +8524,11 @@ HWY_API Vec128 MulOdd(Vec128 a, vget_low_u64(vmull_u32(a_packed, b_packed))); } -HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { - uint64_t hi; - uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 1), vgetq_lane_u64(b.raw, 1), &hi); - return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); } // ------------------------------ TableLookupBytes (Combine, LowerHalf) @@ -7025,7 +8593,7 @@ HWY_API VI TableLookupBytesOr0(V bytes, VI from) { // ---------------------------- AESKeyGenAssist (AESLastRound, TableLookupBytes) -#if HWY_TARGET == HWY_NEON +#if HWY_TARGET != HWY_NEON_WITHOUT_AES template HWY_API Vec128 AESKeyGenAssist(Vec128 v) { alignas(16) static constexpr uint8_t kRconXorMask[16] = { @@ -7038,51 +8606,26 @@ HWY_API Vec128 AESKeyGenAssist(Vec128 v) { const auto sub_word_result = AESLastRound(w13, Load(d, kRconXorMask)); return TableLookupBytes(sub_word_result, Load(d, kRotWordShuffle)); } -#endif // HWY_TARGET == HWY_NEON +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES // ------------------------------ Scatter in generic_ops-inl.h // ------------------------------ Gather in generic_ops-inl.h // ------------------------------ Reductions -namespace detail { - -// N=1 for any T: no-op -template -HWY_INLINE T ReduceMin(hwy::SizeTag /* tag */, Vec128 v) { - return GetLane(v); -} -template -HWY_INLINE T ReduceMax(hwy::SizeTag /* tag */, Vec128 v) { - return GetLane(v); -} -template -HWY_INLINE T ReduceSum(hwy::SizeTag /* tag */, Vec128 v) { - return GetLane(v); -} -template -HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, - Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, - Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, - Vec128 v) { - return v; -} - -// full vectors +// On Armv8 we define ReduceSum and generic_ops defines SumOfLanes via Set. #if HWY_ARCH_ARM_A64 +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + // TODO(janwas): use normal HWY_NEON_DEF, then FULL type list. 
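// Usage sketch for the reductions defined below (illustrative only; assumes
// "hwy/highway.h" is included and the code runs inside HWY_NAMESPACE):
#if 0
void ReductionExamples() {
  const Full128<int32_t> d;
  const auto v = Iota(d, 1);                  // {1, 2, 3, 4}
  const int32_t total = ReduceSum(d, v);      // 10 as a scalar
  const auto broadcast = SumOfLanes(d, v);    // {10, 10, 10, 10}
  const int32_t smallest = ReduceMin(d, v);   // 1
  (void)total; (void)broadcast; (void)smallest;
}
#endif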
#define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix) \ - HWY_API type##_t name(hwy::SizeTag, \ - Vec128 v) { \ + template \ + HWY_API type##_t name(D /* tag */, Vec128 v) { \ return HWY_NEON_EVAL(prefix##infix##suffix, v.raw); \ } @@ -7125,83 +8668,110 @@ HWY_NEON_DEF_REDUCTION_F16(ReduceMax, vmaxv) HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceSum, vaddv) HWY_NEON_DEF_REDUCTION_UI64(ReduceSum, vaddv) +// Emulate missing UI64 and partial N=2. +template +HWY_API TFromD ReduceSum(D /* tag */, VFromD v10) { + return GetLane(v10) + ExtractLane(v10, 1); +} + +template +HWY_API TFromD ReduceMin(D /* tag */, VFromD v10) { + return HWY_MIN(GetLane(v10), ExtractLane(v10, 1)); +} + +template +HWY_API TFromD ReduceMax(D /* tag */, VFromD v10) { + return HWY_MAX(GetLane(v10), ExtractLane(v10, 1)); +} + #if HWY_HAVE_FLOAT16 -HWY_API float16_t ReduceSum(hwy::SizeTag<2>, Vec64 v) { +template +HWY_API float16_t ReduceMin(D d, VFromD v10) { + return GetLane(Min(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceMax(D d, VFromD v10) { + return GetLane(Max(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceSum(D /* tag */, VFromD v) { const float16x4_t x2 = vpadd_f16(v.raw, v.raw); - return GetLane(Vec64(vpadd_f16(x2, x2))); + return GetLane(VFromD(vpadd_f16(x2, x2))); } -HWY_API float16_t ReduceSum(hwy::SizeTag<2> tag, Vec128 v) { - return ReduceSum(tag, LowerHalf(Vec128(vpaddq_f16(v.raw, v.raw)))); +template +HWY_API float16_t ReduceSum(D d, VFromD v) { + const Half dh; + return ReduceSum(dh, LowerHalf(dh, VFromD(vpaddq_f16(v.raw, v.raw)))); } -#endif +#endif // HWY_HAVE_FLOAT16 #undef HWY_NEON_DEF_REDUCTION_CORE_TYPES #undef HWY_NEON_DEF_REDUCTION_F16 #undef HWY_NEON_DEF_REDUCTION_UI64 #undef HWY_NEON_DEF_REDUCTION -// Need some fallback implementations for [ui]64x2 and [ui]16x2. -#define HWY_IF_SUM_REDUCTION(T) HWY_IF_T_SIZE_ONE_OF(T, 1 << 2) -#define HWY_IF_MINMAX_REDUCTION(T) HWY_IF_T_SIZE_ONE_OF(T, (1 << 8) | (1 << 2)) +// ------------------------------ SumOfLanes -// Implement Min/Max/SumOfLanes in terms of the corresponding reduction. -template -HWY_API V MinOfLanes(hwy::SizeTag tag, V v) { - return Set(DFromV(), ReduceMin(tag, v)); +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); } -template -HWY_API V MaxOfLanes(hwy::SizeTag tag, V v) { - return Set(DFromV(), ReduceMax(tag, v)); +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); } -template -HWY_API V SumOfLanes(hwy::SizeTag tag, V v) { - return Set(DFromV(), ReduceSum(tag, v)); +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); } -#else +// On Armv7 we define SumOfLanes and generic_ops defines ReduceSum via GetLane. +#else // !HWY_ARCH_ARM_A64 + +// Armv7 lacks N=2 and 8-bit x4, so enable generic versions of those. +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr // For arm7, we implement reductions using a series of pairwise operations. This // produces the full vector result, so we express Reduce* in terms of *OfLanes. 
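// Roughly what the pairwise macros below expand to for SumOfLanes of a
// uint32x4_t (illustration only, not the literal expansion):
#if 0
uint32x2_t tmp =
    vpadd_u32(vget_high_u32(v.raw), vget_low_u32(v.raw));  // {h0+h1, l0+l1}
tmp = vpadd_u32(tmp, tmp);  // both lanes now hold the total
return Vec128<uint32_t>(vcombine_u32(tmp, tmp));  // broadcast to all lanes
#endif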
#define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t -#define HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) Vec128 #define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \ - HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) name##OfLanes( \ - hwy::SizeTag, Vec128 v) { \ + template \ + HWY_API Vec128 name##OfLanes(D /* d */, \ + Vec128 v) { \ HWY_NEON_BUILD_TYPE_T(type, size) tmp = prefix##_##suffix(v.raw, v.raw); \ if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ - return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size)(tmp); \ - } \ - HWY_API type##_t Reduce##name(hwy::SizeTag tag, \ - Vec128 v) { \ - return GetLane(name##OfLanes(tag, v)); \ + return Vec128(tmp); \ } // For the wide versions, the pairwise operations produce a half-length vector. -// We produce that value with a Reduce*Vector helper method, and express Reduce* -// and *OfLanes in terms of the helper. +// We produce that `tmp` and then Combine. #define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \ suffix) \ - HWY_API HWY_NEON_BUILD_TYPE_T(type, half) \ - Reduce##name##Vector(Vec128 v) { \ + template \ + HWY_API Vec128 name##OfLanes(D /* d */, \ + Vec128 v) { \ HWY_NEON_BUILD_TYPE_T(type, half) tmp; \ tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \ vget_low_##suffix(v.raw)); \ if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \ - return tmp; \ - } \ - HWY_API type##_t Reduce##name(hwy::SizeTag, \ - Vec128 v) { \ - const HWY_NEON_BUILD_TYPE_T(type, half) tmp = Reduce##name##Vector(v); \ - return HWY_NEON_EVAL(vget_lane_##suffix, tmp, 0); \ - } \ - HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) name##OfLanes( \ - hwy::SizeTag, Vec128 v) { \ - const HWY_NEON_BUILD_TYPE_T(type, half) tmp = Reduce##name##Vector(v); \ - return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ - type, size)(vcombine_##suffix(tmp, tmp)); \ + return Vec128(vcombine_##suffix(tmp, tmp)); \ } #define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \ @@ -7227,56 +8797,22 @@ HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Max, vpmax) #undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS #undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION #undef HWY_NEON_DEF_PAIRWISE_REDUCTION -#undef HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION #undef HWY_NEON_BUILD_TYPE_T -// Need fallback min/max implementations for [ui]64x2 and [ui]16x2. 
-#define HWY_IF_SUM_REDUCTION(T) HWY_IF_T_SIZE_ONE_OF(T, 1 << 2 | 1 << 8) -#define HWY_IF_MINMAX_REDUCTION(T) HWY_IF_T_SIZE_ONE_OF(T, 1 << 2 | 1 << 8) - +// GetLane(SumsOf4(v)) is more efficient on ArmV7 NEON than the default +// N=4 I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 #endif -} // namespace detail - -// [ui]16/[ui]64: N=2 -- special case for pairs of very small or large lanes -template -HWY_API Vec128 SumOfLanes(D /* tag */, Vec128 v10) { - return v10 + Reverse2(Simd(), v10); +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); } -template -HWY_API T ReduceSum(D d, Vec128 v10) { - return GetLane(SumOfLanes(d, v10)); -} - -template -HWY_API Vec128 MinOfLanes(D /* tag */, Vec128 v10) { - return Min(v10, Reverse2(Simd(), v10)); -} -template -HWY_API Vec128 MaxOfLanes(D /* tag */, Vec128 v10) { - return Max(v10, Reverse2(Simd(), v10)); -} - -#undef HWY_IF_SUM_REDUCTION -#undef HWY_IF_MINMAX_REDUCTION - -template -HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { - return detail::SumOfLanes(hwy::SizeTag)>(), v); -} -template -HWY_API TFromD ReduceSum(D /* tag */, VFromD v) { - return detail::ReduceSum(hwy::SizeTag)>(), v); -} -template -HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { - return detail::MinOfLanes(hwy::SizeTag)>(), v); -} -template -HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { - return detail::MaxOfLanes(hwy::SizeTag)>(), v); -} +#endif // HWY_ARCH_ARM_A64 // ------------------------------ LoadMaskBits (TestBit) @@ -7345,6 +8881,15 @@ HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { return detail::LoadMaskBits(d, mask_bits); } +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + // ------------------------------ Mask namespace detail { @@ -7674,7 +9219,7 @@ namespace detail { template HWY_INLINE Vec128 Load8Bytes(D /*tag*/, const uint8_t* bytes) { return Vec128(vreinterpretq_u8_u64( - vld1q_dup_u64(reinterpret_cast(bytes)))); + vld1q_dup_u64(HWY_RCAST_ALIGNED(const uint64_t*, bytes)))); } // Load 8 bytes and return half-reg with N <= 8 bytes. @@ -8234,17 +9779,22 @@ namespace detail { #define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from #if HWY_ARCH_ARM_A64 -#define HWY_IF_LOAD_INT(D) HWY_IF_V_SIZE_GT_D(D, 4) -#define HWY_NEON_DEF_FUNCTION_LOAD_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#define HWY_IF_LOAD_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) #else -// Exclude 64x2 and f64x1, which are only supported on aarch64 +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. 
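// Usage sketch for the public LoadInterleaved3/StoreInterleaved3 wrappers
// built from these helpers (illustrative only; assumes "hwy/highway.h" is
// included, code inside HWY_NAMESPACE, and buffers holding at least
// 3 * Lanes(d8) bytes):
#if 0
void RgbToBgr(const uint8_t* HWY_RESTRICT rgb, uint8_t* HWY_RESTRICT bgr) {
  const ScalableTag<uint8_t> d8;
  VFromD<decltype(d8)> r, g, b;
  LoadInterleaved3(d8, rgb, r, g, b);    // rgb holds r0,g0,b0,r1,g1,b1,...
  StoreInterleaved3(b, g, r, d8, bgr);   // write back with channels swapped
}
#endif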
#define HWY_IF_LOAD_INT(D) \ - HWY_IF_V_SIZE_GT_D(D, 4), \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ nullptr #define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) @@ -8287,40 +9837,36 @@ HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) template > HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1) { - auto raw = detail::LoadInterleaved2( - reinterpret_cast*>(unaligned), - detail::Tuple2()); + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(unaligned), + detail::Tuple2()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); } // <= 32 bits: avoid loading more than N bytes by copying to buffer -template > +template > HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1) { // The smallest vector registers are 64-bits and we want space for two. alignas(16) T buf[2 * 8 / sizeof(T)] = {}; CopyBytes(unaligned, buf); - auto raw = detail::LoadInterleaved2( - reinterpret_cast*>(buf), - detail::Tuple2()); + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(buf), + detail::Tuple2()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); } #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved2(D d, T* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1) { const Half dh; VFromD v00, v10, v01, v11; - LoadInterleaved2( - dh, reinterpret_cast*>(unaligned), v00, - v10); - LoadInterleaved2( - dh, reinterpret_cast*>(unaligned + 2), - v01, v11); + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned), v00, v10); + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned + 2), v01, v11); v0 = Combine(d, v01, v00); v1 = Combine(d, v11, v10); } @@ -8331,24 +9877,23 @@ HWY_API void LoadInterleaved2(D d, T* HWY_RESTRICT unaligned, Vec128& v0, template > HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2) { - auto raw = detail::LoadInterleaved3( - reinterpret_cast*>(unaligned), - detail::Tuple3()); + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(unaligned), + detail::Tuple3()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); v2 = VFromD(raw.val[2]); } // <= 32 bits: avoid writing more than N bytes by copying to buffer -template > +template > HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2) { // The smallest vector registers are 64-bits and we want space for three. 
alignas(16) T buf[3 * 8 / sizeof(T)] = {}; CopyBytes(unaligned, buf); - auto raw = detail::LoadInterleaved3( - reinterpret_cast*>(buf), - detail::Tuple3()); + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(buf), + detail::Tuple3()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); v2 = VFromD(raw.val[2]); @@ -8356,17 +9901,14 @@ HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1, Vec128& v2) { const Half dh; VFromD v00, v10, v20, v01, v11, v21; - LoadInterleaved3( - dh, reinterpret_cast*>(unaligned), v00, - v10, v20); - LoadInterleaved3( - dh, reinterpret_cast*>(unaligned + 3), - v01, v11, v21); + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned), v00, v10, v20); + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned + 3), v01, v11, v21); v0 = Combine(d, v01, v00); v1 = Combine(d, v11, v10); v2 = Combine(d, v21, v20); @@ -8379,9 +9921,8 @@ template > HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2, VFromD& v3) { - auto raw = detail::LoadInterleaved4( - reinterpret_cast*>(unaligned), - detail::Tuple4()); + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(unaligned), + detail::Tuple4()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); v2 = VFromD(raw.val[2]); @@ -8389,15 +9930,15 @@ HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, } // <= 32 bits: avoid writing more than N bytes by copying to buffer -template > +template > HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2, VFromD& v3) { alignas(16) T buf[4 * 8 / sizeof(T)] = {}; CopyBytes(unaligned, buf); - auto raw = detail::LoadInterleaved4( - reinterpret_cast*>(buf), - detail::Tuple4()); + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(buf), + detail::Tuple4()); v0 = VFromD(raw.val[0]); v1 = VFromD(raw.val[1]); v2 = VFromD(raw.val[2]); @@ -8406,18 +9947,17 @@ HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, Vec128& v0, Vec128& v1, Vec128& v2, Vec128& v3) { const Half dh; VFromD v00, v10, v20, v30, v01, v11, v21, v31; - LoadInterleaved4( - dh, reinterpret_cast*>(unaligned), v00, - v10, v20, v30); - LoadInterleaved4( - dh, reinterpret_cast*>(unaligned + 4), - v01, v11, v21, v31); + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned), v00, v10, v20, + v30); + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned + 4), v01, v11, v21, + v31); v0 = Combine(d, v01, v00); v1 = Combine(d, v11, v10); v2 = Combine(d, v21, v20); @@ -8435,17 +9975,22 @@ namespace detail { #define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw #if HWY_ARCH_ARM_A64 -#define HWY_IF_STORE_INT(D) HWY_IF_V_SIZE_GT_D(D, 4) -#define HWY_NEON_DEF_FUNCTION_STORE_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#define HWY_IF_STORE_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) #else -// Exclude 64x2 and 
f64x1, which are only supported on aarch64 +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. #define HWY_IF_STORE_INT(D) \ - HWY_IF_V_SIZE_GT_D(D, 4), \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ nullptr #define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) @@ -8476,32 +10021,31 @@ template > HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, T* HWY_RESTRICT unaligned) { detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; - detail::StoreInterleaved2( - tup, reinterpret_cast*>(unaligned)); + detail::StoreInterleaved2(tup, detail::NativeLanePointer(unaligned)); } // <= 32 bits: avoid writing more than N bytes by copying to buffer -template > +template > HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, T* HWY_RESTRICT unaligned) { alignas(16) T buf[2 * 8 / sizeof(T)]; detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; - detail::StoreInterleaved2(tup, - reinterpret_cast*>(buf)); + detail::StoreInterleaved2(tup, detail::NativeLanePointer(buf)); CopyBytes(buf, unaligned); } #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved2(Vec128 v0, Vec128 v1, D d, T* HWY_RESTRICT unaligned) { const Half dh; StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, - reinterpret_cast*>(unaligned)); - StoreInterleaved2( - UpperHalf(dh, v0), UpperHalf(dh, v1), dh, - reinterpret_cast*>(unaligned + 2)); + detail::NativeLanePointer(unaligned)); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, + detail::NativeLanePointer(unaligned + 2)); } #endif // HWY_ARCH_ARM_V7 @@ -8511,32 +10055,31 @@ template > HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, T* HWY_RESTRICT unaligned) { detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; - detail::StoreInterleaved3( - tup, reinterpret_cast*>(unaligned)); + detail::StoreInterleaved3(tup, detail::NativeLanePointer(unaligned)); } // <= 32 bits: avoid writing more than N bytes by copying to buffer -template > +template > HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, T* HWY_RESTRICT unaligned) { alignas(16) T buf[3 * 8 / sizeof(T)]; detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; - detail::StoreInterleaved3(tup, - reinterpret_cast*>(buf)); + detail::StoreInterleaved3(tup, detail::NativeLanePointer(buf)); CopyBytes(buf, unaligned); } #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved3(Vec128 v0, Vec128 v1, Vec128 v2, D d, T* HWY_RESTRICT unaligned) { const Half dh; StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, - reinterpret_cast*>(unaligned)); - StoreInterleaved3( - UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, - reinterpret_cast*>(unaligned + 3)); + detail::NativeLanePointer(unaligned)); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + 
detail::NativeLanePointer(unaligned + 3)); } #endif // HWY_ARCH_ARM_V7 @@ -8546,39 +10089,41 @@ template > HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, VFromD v3, D d, T* HWY_RESTRICT unaligned) { detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; - detail::StoreInterleaved4( - tup, reinterpret_cast*>(unaligned)); + detail::StoreInterleaved4(tup, detail::NativeLanePointer(unaligned)); } // <= 32 bits: avoid writing more than N bytes by copying to buffer -template > +template > HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, VFromD v3, D d, T* HWY_RESTRICT unaligned) { alignas(16) T buf[4 * 8 / sizeof(T)]; detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; - detail::StoreInterleaved4(tup, - reinterpret_cast*>(buf)); + detail::StoreInterleaved4(tup, detail::NativeLanePointer(buf)); CopyBytes(buf, unaligned); } #if HWY_ARCH_ARM_V7 // 64x2: split into two 64x1 -template , HWY_IF_T_SIZE(T, 8)> +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved4(Vec128 v0, Vec128 v1, Vec128 v2, Vec128 v3, D d, T* HWY_RESTRICT unaligned) { const Half dh; StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), LowerHalf(dh, v3), dh, - reinterpret_cast*>(unaligned)); - StoreInterleaved4( - UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), - UpperHalf(dh, v3), dh, - reinterpret_cast*>(unaligned + 4)); + detail::NativeLanePointer(unaligned)); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, + detail::NativeLanePointer(unaligned + 4)); } #endif // HWY_ARCH_ARM_V7 #undef HWY_IF_STORE_INT +// Fall back on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_NEON_IF_EMULATED_D. + // ------------------------------ Additional mask logical operations template HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { @@ -8904,7 +10449,8 @@ namespace detail { // for code folding #undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 #undef HWY_NEON_DEF_FUNCTION_UINTS #undef HWY_NEON_EVAL - +#undef HWY_NEON_IF_EMULATED_D +#undef HWY_NEON_IF_NOT_EMULATED_D } // namespace detail // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/r/src/vendor/highway/hwy/ops/arm_sve-inl.h b/r/src/vendor/highway/hwy/ops/arm_sve-inl.h index 7f7c8cf2..2dde1479 100644 --- a/r/src/vendor/highway/hwy/ops/arm_sve-inl.h +++ b/r/src/vendor/highway/hwy/ops/arm_sve-inl.h @@ -33,6 +33,33 @@ #define HWY_SVE_HAVE_2 0 #endif +// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available: +// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip. +#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_BF16_FEATURE 1 +#else +#define HWY_SVE_HAVE_BF16_FEATURE 0 +#endif + +// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type +// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0. 
+#if HWY_SVE_HAVE_BF16_FEATURE || \ + (HWY_COMPILER_CLANG >= 1200 && defined(__ARM_FEATURE_SVE_BF16)) || \ + HWY_COMPILER_GCC_ACTUAL >= 1000 +#define HWY_SVE_HAVE_BF16_VEC 1 +#else +#define HWY_SVE_HAVE_BF16_VEC 0 +#endif + +// HWY_SVE_HAVE_F32_TO_BF16C is defined to 1 if the SVE svcvt_bf16_f32_x +// and svcvtnt_bf16_f32_x intrinsics are available, even if the __bf16 type +// is disabled +#if HWY_SVE_HAVE_BF16_VEC && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_F32_TO_BF16C 1 +#else +#define HWY_SVE_HAVE_F32_TO_BF16C 0 +#endif + HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { @@ -76,12 +103,29 @@ namespace detail { // for code folding #define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ X_MACRO(float, f, 64, 32, NAME, OP) -#if HWY_SVE_HAVE_BFLOAT16 -#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \ +#define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \ X_MACRO(bfloat, bf, 16, 16, NAME, OP) + +#if HWY_SVE_HAVE_BF16_FEATURE +#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) +// We have both f16 and bf16, so nothing is emulated. + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr #else #define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) -#endif +#define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#endif // HWY_SVE_HAVE_BF16_FEATURE // For all element sizes: #define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ @@ -96,12 +140,16 @@ namespace detail { // for code folding HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) +#define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + // HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks // bf16 overloads for some intrinsics (especially less-common arithmetic). +// However, this does include f16 because SVE supports it unconditionally. 
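// Illustration of how these FOREACH lists (including HWY_SVE_FOREACH_F just
// below) are consumed: a single X-macro definition is instantiated once per
// type in the list. The wrapper here is hypothetical, not from this header.
#if 0
#define MY_NEG(BASE, CHAR, BITS, HALF, NAME, OP)                  \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);     \
  }
HWY_SVE_FOREACH_F(MY_NEG, MyNeg, neg)  // emits MyNeg for f16, f32 and f64
#undef MY_NEG
#endif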
#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ - HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ - HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) // Commonly used type categories for a given element size: #define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ @@ -123,8 +171,7 @@ namespace detail { // for code folding #define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ - HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ - HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) // Commonly used type categories: #define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ @@ -155,7 +202,9 @@ namespace detail { // for code folding }; HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) -HWY_SVE_FOREACH_BF16(HWY_SPECIALIZE, _, _) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) +#endif #undef HWY_SPECIALIZE // Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX @@ -184,15 +233,24 @@ HWY_SVE_FOREACH_BF16(HWY_SPECIALIZE, _, _) } // vector = f(vector, vector), e.g. Add +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } +// All-true mask #define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ } -#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ - return sv##OP##_##CHAR##BITS(a, b); \ +// User-specified mask. Mask=false value is undefined and must be set by caller +// because SVE instructions take it from one of the two inputs, whereas +// AVX-512, RVV and Highway allow a third argument. 
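// Illustrative scalar model, not part of the upstream diff: the _x ("don't
// care") predication that HWY_SVE_RETV_ARGMVV wraps leaves mask=false lanes
// unspecified, so the Masked*Or ops defined further below select `no`
// afterwards via IfThenElse. kN and the helper names are assumptions.
#include <array>
#include <cstddef>

constexpr std::size_t kN = 4;

// Models sv<op>_x(m, a, b): inactive lanes may hold anything (here: a[i]).
inline std::array<int, kN> MaskedAddX(const std::array<bool, kN>& m,
                                      const std::array<int, kN>& a,
                                      const std::array<int, kN>& b) {
  std::array<int, kN> r{};
  for (std::size_t i = 0; i < kN; ++i) r[i] = m[i] ? a[i] + b[i] : a[i];
  return r;
}

// Models MaskedAddOr(no, m, a, b) == IfThenElse(m, MaskedAdd(m, a, b), no).
inline std::array<int, kN> MaskedAddOrModel(const std::array<int, kN>& no,
                                            const std::array<bool, kN>& m,
                                            const std::array<int, kN>& a,
                                            const std::array<int, kN>& b) {
  const std::array<int, kN> sum = MaskedAddX(m, a, b);
  std::array<int, kN> r{};
  for (std::size_t i = 0; i < kN; ++i) r[i] = m[i] ? sum[i] : no[i];
  return r;
}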
+#define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(m, a, b); \ } #define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ @@ -264,26 +322,19 @@ HWY_API size_t Lanes(Simd d) { return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ } HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) -HWY_SVE_FOREACH_BF16(HWY_SVE_FIRSTN, FirstN, whilelt) - -#undef HWY_SVE_FIRSTN - -template -using MFromD = decltype(FirstN(D(), 0)); +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_FIRSTN, FirstN, whilelt) +#endif -#if !HWY_HAVE_FLOAT16 -template -MFromD> FirstN(D /* tag */, size_t count) { +template +svbool_t FirstN(D /* tag */, size_t count) { return FirstN(RebindToUnsigned(), count); } -#endif // !HWY_HAVE_FLOAT16 -#if !HWY_SVE_HAVE_BFLOAT16 -template -MFromD> FirstN(D /* tag */, size_t count) { - return FirstN(RebindToUnsigned(), count); -} -#endif // !HWY_SVE_HAVE_BFLOAT16 +#undef HWY_SVE_FIRSTN + +template +using MFromD = svbool_t; namespace detail { @@ -298,7 +349,7 @@ namespace detail { } HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true -HWY_SVE_FOREACH_BF16(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) #undef HWY_SVE_WRAP_PTRUE HWY_API svbool_t PFalse() { return svpfalse_b(); } @@ -314,6 +365,17 @@ svbool_t MakeMask(D d) { } // namespace detail +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API svbool_t MaskFalse(const D /*d*/) { + return detail::PFalse(); +} + // ================================================== INIT // ------------------------------ Set @@ -326,14 +388,23 @@ svbool_t MakeMask(D d) { } HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#if HWY_SVE_HAVE_BF16_FEATURE // for if-elif chain HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n) -#if !HWY_SVE_HAVE_BFLOAT16 +#elif HWY_SVE_HAVE_BF16_VEC // Required for Zero and VFromD -template -svuint16_t Set(Simd d, bfloat16_t arg) { - return Set(RebindToUnsigned(), arg.bits); +template +HWY_API svbfloat16_t Set(D d, bfloat16_t arg) { + return svreinterpret_bf16_u16( + Set(RebindToUnsigned(), BitCastScalar(arg))); } -#endif // HWY_SVE_HAVE_BFLOAT16 +#else // neither bf16 feature nor vector: emulate with u16 +// Required for Zero and VFromD +template +HWY_API svuint16_t Set(D d, bfloat16_t arg) { + const RebindToUnsigned du; + return Set(du, BitCastScalar(arg)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE #undef HWY_SVE_SET template @@ -350,17 +421,6 @@ VFromD Zero(D d) { return BitCast(d, Set(du, 0)); } -// ------------------------------ Undefined - -#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ - return sv##OP##_##CHAR##BITS(); \ - } - -HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) - // ------------------------------ BitCast namespace detail { @@ -387,24 +447,32 @@ namespace detail { return sv##OP##_##CHAR##BITS##_u8(v); \ } +// U08 is special-cased, hence do not use FOREACH. 
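// Illustrative sketch, not part of the upstream diff: the scalar view of the
// bf16-as-u16 emulation used when neither HWY_SVE_HAVE_BF16_FEATURE nor
// HWY_SVE_HAVE_BF16_VEC is available. A bfloat16 lane is carried as the upper
// 16 bits of its binary32 encoding, so Set/BitCast round-trip through
// uint16_t; BitCastScalar in the diff plays the role of these helpers, whose
// names here are hypothetical.
#include <cstdint>
#include <cstring>

inline std::uint16_t BF16BitsFromF32(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<std::uint16_t>(u >> 16);  // keep sign, exponent, top mantissa bits
}

inline float F32FromBF16Bits(std::uint16_t bits) {
  const std::uint32_t u = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}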
HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) -HWY_SVE_FOREACH_BF16(HWY_SVE_CAST, _, reinterpret) -#undef HWY_SVE_CAST_NOP -#undef HWY_SVE_CAST +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CAST, _, reinterpret) +#else // !(HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC) +template )> +HWY_INLINE svuint8_t BitCastToByte(V v) { + const RebindToUnsigned> du; + return BitCastToByte(BitCast(du, v)); +} -#if !HWY_SVE_HAVE_BFLOAT16 -template -HWY_INLINE VBF16 BitCastFromByte(Simd /* d */, - svuint8_t v) { - return BitCastFromByte(Simd(), v); +template +HWY_INLINE VFromD BitCastFromByte(D d, svuint8_t v) { + const RebindToUnsigned du; + return BitCastFromByte(du, v); } -#endif // !HWY_SVE_HAVE_BFLOAT16 +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST } // namespace detail @@ -413,6 +481,26 @@ HWY_API VFromD BitCast(D d, FromV v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_UNDEFINED, Undefined, undef) +#endif + +template +VFromD Undefined(D d) { + const RebindToUnsigned du; + return BitCast(d, Undefined(du)); +} + // ------------------------------ Tuple // tuples = f(d, v..), e.g. 
Create2 @@ -438,7 +526,9 @@ HWY_API VFromD BitCast(D d, FromV v) { } HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create) -HWY_SVE_FOREACH_BF16(HWY_SVE_CREATE, Create, create) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CREATE, Create, create) +#endif #undef HWY_SVE_CREATE template @@ -463,7 +553,9 @@ using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); } HWY_SVE_FOREACH(HWY_SVE_GET, Get, get) -HWY_SVE_FOREACH_BF16(HWY_SVE_GET, Get, get) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_GET, Get, get) +#endif #undef HWY_SVE_GET #define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ @@ -484,7 +576,9 @@ HWY_SVE_FOREACH_BF16(HWY_SVE_GET, Get, get) } HWY_SVE_FOREACH(HWY_SVE_SET, Set, set) -HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, set) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_SET, Set, set) +#endif #undef HWY_SVE_SET // ------------------------------ ResizeBitCast @@ -495,6 +589,107 @@ HWY_API VFromD ResizeBitCast(D d, FromV v) { return BitCast(d, v); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { + return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API VBF16 Dup128VecFromValues(D d, TFromD t0, TFromD t1, TFromD t2, + TFromD t3, TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { +#if HWY_SVE_HAVE_BF16_FEATURE + (void)d; + return svdupq_n_bf16(t0, t1, t2, t3, t4, t5, t6, t7); +#else + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +#endif +} + +template +HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_s32(t0, t1, t2, t3); +} + +template +HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_u32(t0, t1, t2, t3); +} + +template +HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_f32(t0, t1, t2, t3); +} + +template +HWY_API svint64_t 
Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_s64(t0, t1); +} + +template +HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_u64(t0, t1); +} + +template +HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_f64(t0, t1); +} + // ================================================== LOGICAL // detail::*N() functions accept a scalar argument to avoid extra Set(). @@ -519,6 +714,10 @@ HWY_API V And(const V a, const V b) { // ------------------------------ Or +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, OrN, orr_n) +} // namespace detail + HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) template @@ -632,9 +831,37 @@ HWY_API VBF16 Neg(VBF16 v) { return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); } +// ------------------------------ SaturatedNeg +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg) +#endif // HWY_SVE_HAVE_2 + // ------------------------------ Abs HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) +// ------------------------------ SaturatedAbs +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) +#endif // HWY_SVE_HAVE_2 + // ================================================== ARITHMETIC // Per-target flags to prevent generic_ops-inl.h defining Add etc. @@ -676,13 +903,107 @@ HWY_API svuint64_t SumsOf8(const svuint8_t v) { const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); // Compute pairwise sum of u32 and extend to u64. - // TODO(janwas): on SVE2, we can instead use svaddp. 
+ +#if HWY_SVE_HAVE_2 + return svadalp_u64_x(pg, Zero(du64), sums_of_4); +#else const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); return Add(hi, lo); +#endif +} + +HWY_API svint64_t SumsOf8(const svint8_t v) { + const ScalableTag di32; + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + + const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1); +#if HWY_SVE_HAVE_2 + return svadalp_s64_x(pg, Zero(di64), sums_of_4); +#else + const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and sign-extended) + const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4)); + return Add(hi, lo); +#endif +} + +// ------------------------------ SumsOf2 +#if HWY_SVE_HAVE_2 +namespace detail { + +HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + const ScalableTag di16; + const svbool_t pg = detail::PTrue(di16); + return svadalp_s16_x(pg, Zero(di16), v); +} + +HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + const ScalableTag du16; + const svbool_t pg = detail::PTrue(du16); + return svadalp_u16_x(pg, Zero(du16), v); +} + +HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + const ScalableTag di32; + const svbool_t pg = detail::PTrue(di32); + return svadalp_s32_x(pg, Zero(di32), v); +} + +HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + const ScalableTag du32; + const svbool_t pg = detail::PTrue(du32); + return svadalp_u32_x(pg, Zero(du32), v); +} + +HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) { + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + return svadalp_s64_x(pg, Zero(di64), v); } +HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) { + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + return svadalp_u64_x(pg, Zero(du64), v); +} + +} // namespace detail +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + return svdot_n_s32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + return svdot_n_u32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + return svdot_n_s64(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + return svdot_n_u64(Zero(ScalableTag()), v, 1); +} + +} // namespace detail + // ------------------------------ SaturatedAdd #ifdef HWY_NATIVE_I32_SATURATED_ADDSUB @@ -726,14 +1047,15 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) // ------------------------------ ShiftLeft[Same] -#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ - return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ - } \ 
- HWY_API HWY_SVE_V(BASE, BITS) \ - NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \ - return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits); \ +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, int bits) { \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(bits)); \ } HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) @@ -747,15 +1069,35 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) // ------------------------------ RotateRight -// TODO(janwas): svxar on SVE2 -template +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_ROTATE_RIGHT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + if (kBits == 0) return v; \ + return sv##OP##_##CHAR##BITS(v, Zero(DFromV()), \ + HWY_MAX(kBits, 1)); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) +HWY_SVE_FOREACH_I(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) + +#undef HWY_SVE_ROTATE_RIGHT_N + +#else // !HWY_SVE_HAVE_2 +template HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; - return Or(ShiftRight(v), + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } +#endif // ------------------------------ Shl/r @@ -774,6 +1116,50 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) #undef HWY_SVE_SHIFT +// ------------------------------ RoundingShiftLeft[Same]/RoundingShr + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +#define HWY_SVE_ROUNDING_SHR_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + HWY_IF_CONSTEXPR(kBits == 0) { return v; } \ + \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(HWY_MAX(kBits, 1))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR_N, RoundingShiftRight, rshr_n) + +#undef HWY_SVE_ROUNDING_SHR_N + +#define HWY_SVE_ROUNDING_SHR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToSigned> di; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + Neg(BitCast(di, bits))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR, RoundingShr, rshl) + +#undef HWY_SVE_ROUNDING_SHR + +template +HWY_API V RoundingShiftRightSame(V v, int bits) { + const DFromV d; + using T = TFromD; + return RoundingShr(v, Set(d, static_cast(bits))); +} + +#endif // HWY_SVE_HAVE_2 + // ------------------------------ Min/Max HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min) @@ -803,11 +1189,7 @@ HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul) // ------------------------------ MulHigh -HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) -// Not part of API, used internally: -HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) -HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) -HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) // ------------------------------ MulFixedPoint15 HWY_API 
svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { @@ -830,6 +1212,14 @@ HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { } // ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div) HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) // ------------------------------ ApproximateReciprocal @@ -981,18 +1371,40 @@ HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { } HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +HWY_SVE_FOREACH_BF16(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) #undef HWY_SVE_IF_THEN_ELSE +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElse(const svbool_t mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + // ------------------------------ IfThenElseZero -template + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { - return IfThenElse(mask, yes, Zero(DFromV())); + return IfThenElse(mask, yes, Zero(D())); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElseZero(const svbool_t mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); } // ------------------------------ IfThenZeroElse -template + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { - return IfThenElse(mask, Zero(DFromV()), no); + return IfThenElse(mask, Zero(D()), no); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenZeroElse(const svbool_t mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); } // ------------------------------ Additional mask logical operations @@ -1016,6 +1428,162 @@ HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) { return Not(SetBeforeFirst(m)); } +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svunpklo_b(m); +} + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TWFrom = MakeWide>; + static_assert(sizeof(TWFrom) > sizeof(TFrom), + "sizeof(TWFrom) > sizeof(TFrom) must be true"); + + const Rebind dw_from; + return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m)); +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b8(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b16(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b32(m, m); +} + +template ) / 4)> +HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TNFrom = MakeNarrow>; + static_assert(sizeof(TNFrom) < sizeof(TFrom), + "sizeof(TNFrom) < sizeof(TFrom) must be true"); + + const Rebind dn_from; + return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m)); +} + +// 
------------------------------ LowerHalfOfMask +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { + return m; +} + +// ------------------------------ MaskedAddOr etc. (IfThenElse) + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub) +#endif +} // namespace detail + +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMin(m, a, b), no); +} + +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMax(m, a, b), no); +} + +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedAdd(m, a, b), no); +} + +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSub(m, a, b), no); +} + +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMul(m, a, b), no); +} + +template , hwy::float16_t>() ? (1 << 2) : 0) | + (1 << 4) | (1 << 8))> +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedDiv(m, a, b), no); +} + +// I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div + +#if HWY_SVE_HAVE_2 +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatSub(m, a, b), no); +} +#else +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); +} +#endif + // ================================================== COMPARE // mask = f(vector, vector) @@ -1078,7 +1646,8 @@ HWY_API svbool_t TestBit(const V a, const V bit) { // ------------------------------ MaskFromVec (Ne) template HWY_API svbool_t MaskFromVec(const V v) { - return detail::NeN(v, static_cast>(0)); + using T = TFromV; + return detail::NeN(v, ConvertScalarTo(0)); } // ------------------------------ VecFromMask @@ -1090,6 +1659,22 @@ HWY_API VFromD VecFromMask(const D d, svbool_t mask) { return BitCast(d, IfThenElseZero(mask, Set(di, -1))); } +// ------------------------------ IsNegative (Lt) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API svbool_t IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtN(BitCast(di, v), static_cast(0)); +} + // ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) #if HWY_SVE_HAVE_2 @@ -1159,14 +1744,27 @@ HWY_API svbool_t IsNaN(const V v) { return Ne(v, v); // could also use cmpuo } +// Per-target flag to prevent generic_ops-inl.h 
from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + template HWY_API svbool_t IsInf(const V v) { using T = TFromV; const DFromV d; + const RebindToUnsigned du; const RebindToSigned di; - const VFromD vi = BitCast(di, v); - // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, detail::EqN(Add(vi, vi), hwy::MaxExponentTimes2())); + + // 'Shift left' to clear the sign bit + const VFromD vu = BitCast(du, v); + const VFromD v2 = Add(vu, vu); + // Check for exponent=max and mantissa=0. + const VFromD max2 = Set(di, hwy::MaxExponentTimes2()); + return RebindMask(d, Eq(v2, BitCast(du, max2))); } // Returns whether normal/subnormal/zero. @@ -1187,147 +1785,135 @@ HWY_API svbool_t IsFinite(const V v) { // ================================================== MEMORY -// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream +// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream -#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - return sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ - reinterpret_cast(p)); \ +#define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \ + } \ + template \ + HWY_API void Stream(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ + v); \ + } \ + template \ + HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v); \ } -#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - return sv##OP##_##CHAR##BITS(m, reinterpret_cast(p)); \ - } +HWY_SVE_FOREACH(HWY_SVE_MEM, _, _) +HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _) -#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - /* All-true predicate to load all 128 bits. 
*/ \ - return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \ - reinterpret_cast(p)); \ - } +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} -#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ - HWY_SVE_D(BASE, BITS, N, kPow2) d, \ - HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), reinterpret_cast(p), v); \ - } +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); +} -#define HWY_SVE_BLENDED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ - HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - sv##OP##_##CHAR##BITS(m, reinterpret_cast(p), v); \ - } +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} -HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1) -HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1) -HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1) -HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1) -HWY_SVE_FOREACH(HWY_SVE_BLENDED_STORE, BlendedStore, st1) +// MaskedLoadOr is generic and does not require emulation. -HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD, Load, ld1) -HWY_SVE_FOREACH_BF16(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1) -HWY_SVE_FOREACH_BF16(HWY_SVE_STORE, Store, st1) -HWY_SVE_FOREACH_BF16(HWY_SVE_STORE, Stream, stnt1) -HWY_SVE_FOREACH_BF16(HWY_SVE_BLENDED_STORE, BlendedStore, st1) +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + +#undef HWY_SVE_MEM #if HWY_TARGET != HWY_SVE2_128 namespace detail { -HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) -} // namespace detail -#endif // HWY_TARGET != HWY_SVE2_128 - -#undef HWY_SVE_LOAD -#undef HWY_SVE_MASKED_LOAD -#undef HWY_SVE_LOAD_DUP128 -#undef HWY_SVE_STORE -#undef HWY_SVE_BLENDED_STORE +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \ + detail::NativeLanePointer(p)); \ + } -#if !HWY_SVE_HAVE_BFLOAT16 +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) -template -HWY_API VBF16 Load(Simd d, - const bfloat16_t* HWY_RESTRICT p) { - return Load(RebindToUnsigned(), - reinterpret_cast(p)); +template +HWY_API VFromD LoadDupFull128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p))); } -#endif // !HWY_SVE_HAVE_BFLOAT16 +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 #if HWY_TARGET == HWY_SVE2_128 -// On the HWY_SVE2_128 target, LoadDup128 is the same as Load since vectors +// On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU since vectors // cannot exceed 16 bytes on the HWY_SVE2_128 target. 
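// Illustrative scalar model, not part of the upstream diff: LoadDup128
// replicates the first 16 bytes at p across the whole vector, which is what
// the ld1rq path above provides when vectors are longer than 128 bits.
// LoadDup128Model and num_lanes are assumptions for this sketch.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> LoadDup128Model(const T* p, std::size_t num_lanes) {
  const std::size_t kLanesPer128 = 16 / sizeof(T);
  std::vector<T> v(num_lanes);
  for (std::size_t i = 0; i < num_lanes; ++i) v[i] = p[i % kLanesPer128];
  return v;
}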
template HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { - return Load(d, p); + return LoadU(d, p); } #else // HWY_TARGET != HWY_SVE2_128 -// If D().MaxBytes() <= 16 is true, simply do a Load operation. +// If D().MaxBytes() <= 16 is true, simply do a LoadU operation. template HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { - return Load(d, p); + return LoadU(d, p); } // If D().MaxBytes() > 16 is true, need to load the vector using ld1rq -template , bfloat16_t>()>* = nullptr> +template HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { return detail::LoadDupFull128(d, p); } -#if !HWY_SVE_HAVE_BFLOAT16 - -template -HWY_API VBF16 LoadDup128(D d, const bfloat16_t* HWY_RESTRICT p) { - return detail::LoadDupFull128( - RebindToUnsigned(), - reinterpret_cast(p)); -} -#endif // !HWY_SVE_HAVE_BFLOAT16 - #endif // HWY_TARGET != HWY_SVE2_128 -#if !HWY_SVE_HAVE_BFLOAT16 - -template -HWY_API void Store(VBF16 v, Simd d, - bfloat16_t* HWY_RESTRICT p) { - Store(v, RebindToUnsigned(), - reinterpret_cast(p)); -} - -#endif - -// ------------------------------ Load/StoreU +// ------------------------------ Load/Store // SVE only requires lane alignment, not natural alignment of the entire -// vector. +// vector, so Load/Store are the same as LoadU/StoreU. template -HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { - return Load(d, p); +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); } template -HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { - Store(v, d, p); +HWY_API void Store(const V v, D d, TFromD* HWY_RESTRICT p) { + StoreU(v, d, p); } // ------------------------------ MaskedLoadOr @@ -1362,8 +1948,8 @@ HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ - HWY_SVE_V(int, BITS) index) { \ - sv##OP##_s##BITS##index_##CHAR##BITS(m, base, index, v); \ + HWY_SVE_V(int, BITS) indices) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices, v); \ } HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) @@ -1398,10 +1984,13 @@ HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT p, #define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ template \ HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ - HWY_SVE_V(int, BITS) index) { \ - return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, index); \ + HWY_SVE_V(int, BITS) indices) { \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ + return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices); \ } HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) @@ -1410,6 +1999,13 @@ HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex, #undef HWY_SVE_GATHER_OFFSET #undef HWY_SVE_MASKED_GATHER_INDEX +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, svbool_t m, D d, + const TFromD* HWY_RESTRICT p, + VFromD> indices) { + return IfThenElse(m, MaskedGatherIndex(m, d, p, indices), no); +} + template HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT p, VFromD> indices) { @@ -1430,12 +2026,13 @@ HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT p, HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ const HWY_SVE_T(BASE, 
BITS) * HWY_RESTRICT unaligned, \ HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ - const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ v0 = svget2(tuple, 0); \ v1 = svget2(tuple, 1); \ } HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD2, LoadInterleaved2, ld2) #undef HWY_SVE_LOAD2 @@ -1447,13 +2044,14 @@ HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ HWY_SVE_V(BASE, BITS) & v2) { \ - const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ v0 = svget3(tuple, 0); \ v1 = svget3(tuple, 1); \ v2 = svget3(tuple, 2); \ } HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD3, LoadInterleaved3, ld3) #undef HWY_SVE_LOAD3 @@ -1465,27 +2063,31 @@ HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ - const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ v0 = svget4(tuple, 0); \ v1 = svget4(tuple, 1); \ v2 = svget4(tuple, 2); \ v3 = svget4(tuple, 3); \ } HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD4, LoadInterleaved4, ld4) #undef HWY_SVE_LOAD4 // ------------------------------ StoreInterleaved2 -#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ - HWY_SVE_D(BASE, BITS, N, kPow2) d, \ - HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, Create2(d, v0, v1)); \ +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create2(d, v0, v1)); \ } HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE2, StoreInterleaved2, st2) #undef HWY_SVE_STORE2 @@ -1497,10 +2099,12 @@ HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) HWY_SVE_V(BASE, BITS) v2, \ HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ - sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ Create3(d, v0, v1, v2)); \ } HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE3, StoreInterleaved3, st3) #undef HWY_SVE_STORE3 @@ -1512,13 +2116,18 @@ HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ HWY_SVE_D(BASE, BITS, N, kPow2) d, \ HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ - 
sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ Create4(d, v0, v1, v2, v3)); \ } HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE4, StoreInterleaved4, st4) #undef HWY_SVE_STORE4 +// Fall back on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_SVE_IF_EMULATED_D. + // ================================================== CONVERT // ------------------------------ PromoteTo @@ -1602,6 +2211,22 @@ HWY_API svfloat32_t PromoteTo(Simd /* d */, return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); } +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f16_x(detail::PTrue(Simd()), + detail::ZipLowerSame(vv, vv)); +} + template HWY_API svfloat64_t PromoteTo(Simd /* d */, const svfloat32_t v) { @@ -1637,19 +2262,43 @@ HWY_API svuint64_t PromoteTo(Simd /* d */, return svcvt_u64_f32_x(detail::PTrue(Simd()), vv); } -// For 16-bit Compress +// ------------------------------ PromoteUpperTo + namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) #undef HWY_SVE_PROMOTE_TO +} // namespace detail -template -HWY_API svfloat32_t PromoteUpperTo(Simd df, svfloat16_t v) { - const RebindToUnsigned du; - const RepartitionToNarrow dn; - return BitCast(df, PromoteUpperTo(du, BitCast(dn, v))); +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned->Unsigned or Signed->Signed +template , typename TV = TFromV, + hwy::EnableIf() && IsInteger() && + (IsSigned() == IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + if (detail::IsFull(d)) { + return detail::PromoteUpperTo(d, v); + } + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); } -} // namespace detail +// Differing signs or either is float +template , typename TV = TFromV, + hwy::EnableIf() || !IsInteger() || + (IsSigned() != IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
+ const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} // ------------------------------ DemoteTo U @@ -1959,6 +2608,29 @@ HWY_API svuint8_t DemoteTo(Simd dn, const svuint64_t v) { return TruncateTo(dn, vn); } +// ------------------------------ Unsigned to signed demotions + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on SVE/SVE2 as the SVE/SVE2 targets have +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template ) - 1)> +HWY_API VFromD DemoteTo(D dn, V v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::SaturateU>(v))); +} + // ------------------------------ ConcatEven/ConcatOdd // WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the @@ -1972,10 +2644,22 @@ namespace detail { } HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, + uzp1) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, + uzp2) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC #if defined(__ARM_FEATURE_SVE_MATMUL_FP64) HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) -#endif +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, + ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, + uzp2q) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +#endif // defined(__ARM_FEATURE_SVE_MATMUL_FP64) #undef HWY_SVE_CONCAT_EVERY_SECOND // Used to slide up / shift whole register left; mask indicates which range @@ -1986,6 +2670,16 @@ HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) return sv##OP##_##CHAR##BITS(mask, lo, hi); \ } HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#if HWY_SVE_HAVE_BF16_FEATURE +HWY_SVE_FOREACH_BF16(HWY_SVE_SPLICE, Splice, splice) +#else +template )> +HWY_INLINE V Splice(V hi, V lo, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Splice(BitCast(du, hi), BitCast(du, lo), mask)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE #undef HWY_SVE_SPLICE } // namespace detail @@ -2010,21 +2704,83 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); } -// ------------------------------ DemoteTo F +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// Signed to signed PromoteEvenTo: 1 instruction instead of 2 in generic-inl.h. +// Might as well also enable unsigned to unsigned, though it is just an And. 
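// Illustrative scalar model, not part of the upstream diff: PromoteEvenTo for
// int8 -> int16. Result lane i is the sign-extended even input lane 2*i; on
// little-endian SVE that even narrow lane already occupies the low half of the
// corresponding wide lane, so a single extb-style sign extension suffices.
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::vector<std::int16_t> PromoteEvenToI16(
    const std::vector<std::int8_t>& v) {
  std::vector<std::int16_t> out(v.size() / 2);
  for (std::size_t i = 0; i < out.size(); ++i) {
    out[i] = static_cast<std::int16_t>(v[2 * i]);  // sign-extend even lane
  }
  return out;
}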
+namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extb) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, exth) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extw) +} // namespace detail + +#include "hwy/ops/inside-inl.h" + +// ------------------------------ DemoteTo F + +// We already toggled HWY_NATIVE_F16C above. + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat16_t in_lo16 = svcvt_f16_f64_x(detail::PTrue(d), v); + const svfloat16_t in_even = detail::ConcatEvenFull(in_lo16, in_lo16); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +#if !HWY_SVE_HAVE_F32_TO_BF16C +namespace detail { + +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 + +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +HWY_INLINE svuint32_t RoundF32ForDemoteToBF16(svfloat32_t v) { + const DFromV df32; + const RebindToUnsigned du32; -// We already toggled HWY_NATIVE_F16C above. + const auto is_non_nan = Eq(v, v); + const auto bits32 = BitCast(du32, v); -template -HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { - const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); - return detail::ConcatEvenFull(in_even, - in_even); // lower half + const auto round_incr = + detail::AddN(detail::AndN(ShiftRight<16>(bits32), 1u), 0x7FFFu); + return MaskedAddOr(detail::OrN(bits32, 0x00400000u), is_non_nan, bits32, + round_incr); } +} // namespace detail +#endif // !HWY_SVE_HAVE_F32_TO_BF16C + template HWY_API VBF16 DemoteTo(Simd dbf16, svfloat32_t v) { - const svuint16_t in_even = BitCast(ScalableTag(), v); - return BitCast(dbf16, detail::ConcatOddFull(in_even, in_even)); // lower half +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), v); + return detail::ConcatEvenFull(in_even, in_even); +#else + const svuint16_t in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(v)); + return BitCast(dbf16, detail::ConcatOddFull(in_odd, in_odd)); // lower half +#endif } template @@ -2065,32 +2821,31 @@ HWY_API svfloat32_t DemoteTo(Simd d, const svuint64_t v) { // ------------------------------ ConvertTo F #define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ - /* signed integers */ \ + /* Float from signed */ \ template \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } \ - /* unsigned integers */ \ + /* Float from unsigned */ \ template \ HWY_API HWY_SVE_V(BASE, BITS) \ NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } \ - /* Truncates (rounds toward zero). 
*/ \ + /* Signed from float, rounding toward zero */ \ template \ HWY_API HWY_SVE_V(int, BITS) \ NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } \ - /* Truncates to unsigned (rounds toward zero). */ \ + /* Unsigned from float, rounding toward zero */ \ template \ HWY_API HWY_SVE_V(uint, BITS) \ NAME(HWY_SVE_D(uint, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_u##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ } -// API only requires f32 but we provide f64 for use by Iota. HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) #undef HWY_SVE_CONVERT @@ -2101,22 +2856,31 @@ HWY_API VFromD NearestInt(VF v) { return ConvertTo(DI(), Round(v)); } +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. + return DemoteTo(di32, Round(v)); +} + // ------------------------------ Iota (Add, ConvertTo) -#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - HWY_SVE_T(BASE, BITS) first) { \ - return sv##OP##_##CHAR##BITS(first, 1); \ +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, T2 first) { \ + return sv##OP##_##CHAR##BITS( \ + ConvertScalarTo(first), 1); \ } HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) #undef HWY_SVE_IOTA -template -HWY_API VFromD Iota(const D d, TFromD first) { +template +HWY_API VFromD Iota(const D d, T2 first) { const RebindToSigned di; - return detail::AddN(ConvertTo(d, Iota(di, 0)), first); + return detail::AddN(ConvertTo(d, Iota(di, 0)), + ConvertScalarTo>(first)); } // ------------------------------ InterleaveLower @@ -2147,12 +2911,10 @@ HWY_API V InterleaveLower(const V a, const V b) { // Only use zip2 if vector are a powers of two, otherwise getting the actual // "upper half" requires MaskUpperHalf. -#if HWY_TARGET == HWY_SVE2_128 namespace detail { // Unlike Highway's ZipUpper, this returns the same type. HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) } // namespace detail -#endif // Full vector: guaranteed to have at least one block template , @@ -2184,6 +2946,30 @@ HWY_API V InterleaveUpper(D d, const V a, const V b) { return InterleaveUpper(DFromV(), a, b); } +// ------------------------------ InterleaveWholeLower +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +template +HWY_API VFromD InterleaveWholeLower(D /*d*/, VFromD a, VFromD b) { + return detail::ZipLowerSame(a, b); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) { + return detail::ZipUpperSame(a, b); + } + + const Half d2; + return InterleaveWholeLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); +} + // ------------------------------ Per4LaneBlockShuffle namespace detail { @@ -2432,7 +3218,13 @@ HWY_API V UpperHalf(const DH dh, const V v) { // ================================================== REDUCE -// These return T, whereas the Highway op returns a broadcasted vector. +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// These return T, suitable for ReduceSum. 
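// Illustrative scalar model, not part of the upstream diff: the reductions
// below return a single lane value (ReduceSum/ReduceMin/ReduceMax), and
// SumOfLanes etc. are then just Set(d, ReduceSum(d, v)), i.e. the same value
// broadcast back into every lane.
#include <vector>

inline int ReduceSumModel(const std::vector<int>& v) {
  int sum = 0;
  for (int x : v) sum += x;
  return sum;
}

inline std::vector<int> SumOfLanesModel(const std::vector<int>& v) {
  return std::vector<int>(v.size(), ReduceSumModel(v));  // broadcast
}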
namespace detail { #define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ @@ -2462,24 +3254,53 @@ HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) #undef HWY_SVE_REDUCE_ADD } // namespace detail -template -V SumOfLanes(D d, V v) { - return Set(d, detail::SumOfLanesM(detail::MakeMask(d), v)); -} +// detail::SumOfLanesM, detail::MinOfLanesM, and detail::MaxOfLanesM is more +// efficient for N=4 I8/U8 reductions on SVE than the default implementations +// of the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in +// generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr -template -TFromV ReduceSum(D d, V v) { +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { return detail::SumOfLanesM(detail::MakeMask(d), v); } -template -V MinOfLanes(D d, V v) { - return Set(d, detail::MinOfLanesM(detail::MakeMask(d), v)); +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return detail::MinOfLanesM(detail::MakeMask(d), v); } -template -V MaxOfLanes(D d, V v) { - return Set(d, detail::MaxOfLanesM(detail::MakeMask(d), v)); +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return detail::MaxOfLanesM(detail::MakeMask(d), v); +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); } // ================================================== SWIZZLE @@ -2510,11 +3331,15 @@ HWY_API TFromV ExtractLane(V v, size_t i) { } // ------------------------------ InsertLane (IfThenElse) -template -HWY_API V InsertLane(const V v, size_t i, TFromV t) { +template +HWY_API V InsertLane(const V v, size_t i, T t) { + static_assert(sizeof(TFromV) == sizeof(T), "Lane size mismatch"); const DFromV d; - const auto is_i = detail::EqN(Iota(d, 0), static_cast>(i)); - return IfThenElse(RebindMask(d, is_i), Set(d, t), v); + const RebindToSigned di; + using TI = TFromD; + const svbool_t is_i = detail::EqN(Iota(di, 0), static_cast(i)); + return IfThenElse(RebindMask(d, is_i), + Set(d, hwy::ConvertScalarTo>(t)), v); } // ------------------------------ DupEven @@ -2569,6 +3394,18 @@ HWY_API V OddEven(const V odd, const V even) { #endif // HWY_TARGET +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveEven(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveOdd(a, b); +} + // ------------------------------ OddEvenBlocks template HWY_API V OddEvenBlocks(const V odd, const V even) { @@ -2623,6 +3460,9 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { } HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE, TableLookupLanes, tbl) +#endif #undef HWY_SVE_TABLE #if HWY_SVE_HAVE_2 @@ -2634,6 +3474,10 @@ namespace detail { } HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, 
tbl2) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, + tbl2) +#endif #undef HWY_SVE_TABLE } // namespace detail #endif // HWY_SVE_HAVE_2 @@ -2705,6 +3549,9 @@ namespace detail { } HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_REVERSE, ReverseFull, rev) +#endif #undef HWY_SVE_REVERSE } // namespace detail @@ -2775,14 +3622,14 @@ HWY_API VFromD Reverse2(D d, const VFromD v) { // 3210 template HWY_API VFromD Reverse4(D d, const VFromD v) { const RebindToUnsigned du; - const RepartitionToWide> du32; + const RepartitionToWideX2 du32; return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v))); } template HWY_API VFromD Reverse4(D d, const VFromD v) { const RebindToUnsigned du; - const RepartitionToWide> du64; + const RepartitionToWideX2 du64; return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v))); } @@ -2943,20 +3790,23 @@ HWY_API V BroadcastBlock(V v) { static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), "Invalid block index"); + const RebindToUnsigned du; // for bfloat16_t + using VU = VFromD; + const VU vu = BitCast(du, v); + #if HWY_TARGET == HWY_SVE_256 - return (kBlockIdx == 0) ? ConcatLowerLower(d, v, v) - : ConcatUpperUpper(d, v, v); + return BitCast(d, (kBlockIdx == 0) ? ConcatLowerLower(du, vu, vu) + : ConcatUpperUpper(du, vu, vu)); #else - const RebindToUnsigned du; using TU = TFromD; constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); constexpr size_t kBlockOffset = static_cast(kBlockIdx) * kLanesPerBlock; - const auto idx = detail::AddN( + const VU idx = detail::AddN( detail::AndN(Iota(du, TU{0}), static_cast(kLanesPerBlock - 1)), static_cast(kBlockOffset)); - return TableLookupLanes(v, idx); + return BitCast(d, TableLookupLanes(vu, idx)); #endif } @@ -3455,6 +4305,95 @@ HWY_API VFromD ZipUpper(DW dw, V a, V b) { // ================================================== Ops with dependencies +// ------------------------------ AddSub (Reverse2) + +// NOTE: svcadd_f*_x(HWY_SVE_PTRUE(BITS), a, b, 90) computes a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. + +#define HWY_SVE_ADDSUB_F(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, Reverse2(d, b), \ + 90); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_ADDSUB_F, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_F + +// NOTE: svcadd_s*(a, b, 90) and svcadd_u*(a, b, 90) compute a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. 
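The two NOTEs above pin down what svcadd with a 90-degree rotation does to lane pairs; reversing b first is what turns it into the intended AddSub. Spelled out: with c = Reverse2(d, b), c[i + 1] == b[i] for even i and c[i - 1] == b[i] for odd i, so the result is a[i] - b[i] in even lanes and a[i] + b[i] in odd lanes. A scalar reference model of that behavior (illustration only):

    #include <cstddef>
    // Per-lane result AddSub is expected to produce:
    // even lanes subtract, odd lanes add.
    template <typename T>
    void AddSubRef(const T* a, const T* b, T* out, size_t n) {
      for (size_t i = 0; i < n; ++i) {
        out[i] = (i % 2 == 0) ? static_cast<T>(a[i] - b[i])
                              : static_cast<T>(a[i] + b[i]);
      }
    }
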
+ +#if HWY_SVE_HAVE_2 +#define HWY_SVE_ADDSUB_UI(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS(a, Reverse2(d, b), 90); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ADDSUB_UI, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_UI + +// Disable the default implementation of AddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of AddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE that do not support +// SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ MulAddSub (AddSub) + +template , 1), HWY_IF_FLOAT_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using T = TFromV; + + const DFromV d; + const T neg_zero = ConvertScalarTo(-0.0f); + + return MulAdd(mul, x, AddSub(Set(d, neg_zero), sub_or_add)); +} + +#if HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +template , 1), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + const DFromV d; + return MulAdd(mul, x, AddSub(Zero(d), sub_or_add)); +} + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE targets that do not +// support SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + // ------------------------------ PromoteTo bfloat16 (ZipLower) template HWY_API svfloat32_t PromoteTo(Simd df32, VBF16 v) { @@ -3462,15 +4401,142 @@ HWY_API svfloat32_t PromoteTo(Simd df32, VBF16 v) { return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v))); } +// ------------------------------ PromoteEvenTo/PromoteOddTo (ConcatOddFull) + +namespace detail { + +// Signed to signed PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<2> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint8_t v) { + return svextb_s16_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint16_t v) { + return svexth_s32_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + return svextw_s64_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +// F16->F32 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat16_t v) { + const Repartition d_from; + return svcvt_f32_f16_x(detail::PTrue(d_from), v); +} + +// F32->F64 
PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_f64_f32_x(detail::PTrue(d_from), v); +} + +// I32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + const Repartition d_from; + return svcvt_f64_s32_x(detail::PTrue(d_from), v); +} + +// U32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, + svuint32_t v) { + const Repartition d_from; + return svcvt_f64_u32_x(detail::PTrue(d_from), v); +} + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_s64_f32_x(detail::PTrue(d_from), v); +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_u64_f32_x(detail::PTrue(d_from), v); +} + +// F16->F32 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<4> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat16_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// I32/U32/F32->F64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + FromTypeTag from_type_tag, D d_to, V v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// F32->I64/U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(ToTypeTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat32_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +} // namespace detail + // ------------------------------ ReorderDemote2To (OddEven) template HWY_API VBF16 ReorderDemote2To(Simd dbf16, svfloat32_t a, svfloat32_t b) { - const RebindToUnsigned du16; - const Repartition du32; - const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return svcvtnt_bf16_f32_x(b_in_even, detail::PTrue(dbf16), a); +#else + (void)dbf16; + const auto a_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, detail::InterleaveOdd(b_in_odd, a_in_odd)); +#endif } template @@ -3608,6 +4674,14 @@ HWY_API svuint32_t ReorderDemote2To(Simd d32, svuint64_t a, #endif } +template ) / 2)> +HWY_API VFromD ReorderDemote2To(D dn, V a, V b) { + const auto clamped_a = BitCast(dn, detail::SaturateU>(a)); + const auto clamped_b = BitCast(dn, detail::SaturateU>(b)); + return detail::InterleaveEven(clamped_a, clamped_b); +} + template ), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2)> @@ -3618,21 +4692,55 @@ HWY_API 
VFromD OrderedDemote2To(D dn, V a, V b) { return Combine(dn, demoted_b, demoted_a); } -template -HWY_API VBF16 OrderedDemote2To(D dn, svfloat32_t a, svfloat32_t b) { - const Half dnh; - const RebindToUnsigned dn_u; - const RebindToUnsigned dnh_u; - const auto demoted_a = DemoteTo(dnh, a); - const auto demoted_b = DemoteTo(dnh, b); - return BitCast( - dn, Combine(dn_u, BitCast(dnh_u, demoted_b), BitCast(dnh_u, demoted_a))); +template +HWY_API VBF16 OrderedDemote2To(Simd dbf16, svfloat32_t a, + svfloat32_t b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + (void)dbf16; + const VBF16 a_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), a); + const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return ConcatEven(dbf16, b_in_even, a_in_even); +#else + const RebindToUnsigned du16; + const svuint16_t a_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const svuint16_t b_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, ConcatOdd(du16, b_in_odd, a_in_odd)); // lower half +#endif +} + +// ------------------------------ I8/U8/I16/U16 Div + +template +HWY_API V Div(V a, V b) { + const DFromV d; + const Half dh; + const RepartitionToWide dw; + + const auto q_lo = + Div(PromoteTo(dw, LowerHalf(dh, a)), PromoteTo(dw, LowerHalf(dh, b))); + const auto q_hi = Div(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)); + + return OrderedDemote2To(d, q_lo, q_hi); +} + +// ------------------------------ I8/U8/I16/U16 MaskedDivOr +template +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, Div(a, b), no); } -// ------------------------------ ZeroIfNegative (Lt, IfThenElse) +// ------------------------------ Mod (Div, NegMulAdd) template -HWY_API V ZeroIfNegative(const V v) { - return IfThenZeroElse(detail::LtN(v, 0), v); +HWY_API V Mod(V a, V b) { + return NegMulAdd(Div(a, b), b, a); +} + +// ------------------------------ MaskedModOr (Mod) +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); } // ------------------------------ BroadcastSignBit (ShiftRight) @@ -3645,22 +4753,30 @@ HWY_API V BroadcastSignBit(const V v) { template HWY_API V IfNegativeThenElse(V v, V yes, V no) { static_assert(IsSigned>(), "Only works for signed/float"); - const DFromV d; - const RebindToSigned di; - - const svbool_t m = detail::LtN(BitCast(di, v), 0); - return IfThenElse(m, yes, no); + return IfThenElse(IsNegative(v), yes, no); } // ------------------------------ AverageRound (ShiftRight) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + #if HWY_SVE_HAVE_2 -HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) -HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) #else -template -V AverageRound(const V a, const V b) { - return ShiftRight<1>(detail::AddN(Add(a, b), 1)); +template +HWY_API V AverageRound(const V a, const V b) { + return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)), + detail::AndN(Or(a, b), 1)); } #endif // HWY_SVE_HAVE_2 @@ -3735,6 +4851,84 @@ HWY_INLINE svbool_t LoadMaskBits(D /* tag */, return TestBit(vbits, bit); } +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) 
mask_bits &= (1u << kN) - 1; + + // Replicate the lower 8 bits of mask_bits to each u8 lane + const svuint8_t bytes = BitCast(du, Set(du, static_cast(mask_bits))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(bytes, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du16; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const svuint8_t bytes = + BitCast(du, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, ShiftRight<3>(Iota(du, 0))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint32_t bit = svdupq_n_u32(1, 2, 4, 8); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + if (MaxLanes(d) < 2) mask_bits &= 1u; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint64_t bit = svdupq_n_u64(1, 2); + return TestBit(BitCast(du, bytes), bit); +} + // ------------------------------ StoreMaskBits namespace detail { @@ -4100,12 +5294,13 @@ HWY_INLINE VFromD LaneIndicesFromByteIndices(D, svuint8_t idx) { template HWY_INLINE V ExpandLoop(V v, svbool_t mask) { const DFromV d; + using T = TFromV; uint8_t mask_bytes[256 / 8]; StoreMaskBits(d, mask, mask_bytes); // ShiftLeftLanes is expensive, so we're probably better off storing to memory // and loading the final result. - alignas(16) TFromV out[2 * MaxLanes(d)]; + alignas(16) T out[2 * MaxLanes(d)]; svbool_t next = svpfalse_b(); size_t input_consumed = 0; @@ -4117,7 +5312,7 @@ HWY_INLINE V ExpandLoop(V v, svbool_t mask) { // instruction for variable-shift-reg, but we can splice. 
const V vH = detail::Splice(v, v, next); input_consumed += PopCount(mask_bits); - next = detail::GeN(iota, static_cast>(input_consumed)); + next = detail::GeN(iota, ConvertScalarTo(input_consumed)); const auto idx = detail::LaneIndicesFromByteIndices( d, detail::IndicesForExpandFromBits(mask_bits)); @@ -4594,12 +5789,24 @@ HWY_API VFromD MulOdd(const V a, const V b) { #endif } +HWY_API svint64_t MulEven(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { const auto lo = Mul(a, b); const auto hi = MulHigh(a, b); return detail::InterleaveEven(lo, hi); } +HWY_API svint64_t MulOdd(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { const auto lo = Mul(a, b); const auto hi = MulHigh(a, b); @@ -4609,24 +5816,15 @@ HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { // ------------------------------ WidenMulPairwiseAdd template -HWY_API svfloat32_t WidenMulPairwiseAdd(Simd df32, VBF16 a, +HWY_API svfloat32_t WidenMulPairwiseAdd(Simd df, VBF16 a, VBF16 b) { -#if HWY_SVE_HAVE_BFLOAT16 - const svfloat32_t even = svbfmlalb_f32(Zero(df32), a, b); +#if HWY_SVE_HAVE_F32_TO_BF16C + const svfloat32_t even = svbfmlalb_f32(Zero(df), a, b); return svbfmlalt_f32(even, a, b); #else - const RebindToUnsigned du32; - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), - Mul(BitCast(df32, ao), BitCast(df32, bo))); -#endif // HWY_SVE_HAVE_BFLOAT16 + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +#endif // HWY_SVE_HAVE_BF16_FEATURE } template @@ -4636,14 +5834,8 @@ HWY_API svint32_t WidenMulPairwiseAdd(Simd d32, svint16_t a, (void)d32; return svmlalt_s32(svmullb_s32(a, b), a, b); #else - const svbool_t pg = detail::PTrue(d32); - // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. - // Fortunately SVE has sign-extension for the even lanes. - const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a)); - const svint32_t be = svexth_s32_x(pg, BitCast(d32, b)); - const svint32_t ao = ShiftRight<16>(BitCast(d32, a)); - const svint32_t bo = ShiftRight<16>(BitCast(d32, b)); - return svmla_s32_x(pg, svmul_s32_x(pg, ao, bo), ae, be); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); #endif } @@ -4654,43 +5846,59 @@ HWY_API svuint32_t WidenMulPairwiseAdd(Simd d32, (void)d32; return svmlalt_u32(svmullb_u32(a, b), a, b); #else - const svbool_t pg = detail::PTrue(d32); - // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. - // Fortunately SVE has sign-extension for the even lanes. 
- const svuint32_t ae = svexth_u32_x(pg, BitCast(d32, a)); - const svuint32_t be = svexth_u32_x(pg, BitCast(d32, b)); - const svuint32_t ao = ShiftRight<16>(BitCast(d32, a)); - const svuint32_t bo = ShiftRight<16>(BitCast(d32, b)); - return svmla_u32_x(pg, svmul_u32_x(pg, ao, bo), ae, be); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +#endif +} + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT #endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a), + detail::ZipLowerSame(b, b)); } +#endif // HWY_SVE_HAVE_2 + // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template -HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, - VBF16 a, VBF16 b, - const svfloat32_t sum0, - svfloat32_t& sum1) { -#if HWY_SVE_HAVE_BFLOAT16 - (void)df32; - sum1 = svbfmlalt_f32(sum1, a, b); - return svbfmlalb_f32(sum0, a, b); +#if HWY_SVE_HAVE_BF16_FEATURE + +// NOTE: we currently do not use SVE BFDOT for bf16 ReorderWidenMulAccumulate +// because, apparently unlike NEON, it uses round to odd unless the additional +// FEAT_EBF16 feature is available and enabled. +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 #else - const RebindToUnsigned du32; - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); -#endif // HWY_SVE_HAVE_BFLOAT16 +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template +HWY_API svfloat32_t MulEvenAdd(Simd /* d */, VBF16 a, VBF16 b, + const svfloat32_t c) { + return svbfmlalb_f32(c, a, b); +} + +template +HWY_API svfloat32_t MulOddAdd(Simd /* d */, VBF16 a, VBF16 b, + const svfloat32_t c) { + return svbfmlalt_f32(c, a, b); } +#endif // HWY_SVE_HAVE_BF16_FEATURE + template HWY_API svint32_t ReorderWidenMulAccumulate(Simd d32, svint16_t a, svint16_t b, @@ -4701,15 +5909,10 @@ HWY_API svint32_t ReorderWidenMulAccumulate(Simd d32, sum1 = svmlalt_s32(sum1, a, b); return svmlalb_s32(sum0, a, b); #else - const svbool_t pg = detail::PTrue(d32); - // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. - // Fortunately SVE has sign-extension for the even lanes. - const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a)); - const svint32_t be = svexth_s32_x(pg, BitCast(d32, b)); - const svint32_t ao = ShiftRight<16>(BitCast(d32, a)); - const svint32_t bo = ShiftRight<16>(BitCast(d32, b)); - sum1 = svmla_s32_x(pg, sum1, ao, bo); - return svmla_s32_x(pg, sum0, ae, be); + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. 
+ sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); #endif } @@ -4723,15 +5926,10 @@ HWY_API svuint32_t ReorderWidenMulAccumulate(Simd d32, sum1 = svmlalt_u32(sum1, a, b); return svmlalb_u32(sum0, a, b); #else - const svbool_t pg = detail::PTrue(d32); - // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. - // Fortunately SVE has sign-extension for the even lanes. - const svuint32_t ae = svexth_u32_x(pg, BitCast(d32, a)); - const svuint32_t be = svexth_u32_x(pg, BitCast(d32, b)); - const svuint32_t ao = ShiftRight<16>(BitCast(d32, a)); - const svuint32_t bo = ShiftRight<16>(BitCast(d32, b)); - sum1 = svmla_u32_x(pg, sum1, ao, bo); - return svmla_u32_x(pg, sum0, ae, be); + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. + sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); #endif } @@ -4817,8 +6015,10 @@ HWY_API VFromD SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a, // ------------------------------ AESRound / CLMul +// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a +// baseline, in which case we check for AES support at runtime. #if defined(__ARM_FEATURE_SVE2_AES) || \ - (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH) + (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH && HWY_BASELINE_SVE2 == 0) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES @@ -5059,14 +6259,15 @@ HWY_API V HighestSetBitIndex(V v) { } // ================================================== END MACROS -namespace detail { // for code folding #undef HWY_SVE_ALL_PTRUE #undef HWY_SVE_D #undef HWY_SVE_FOREACH #undef HWY_SVE_FOREACH_BF16 +#undef HWY_SVE_FOREACH_BF16_UNCONDITIONAL #undef HWY_SVE_FOREACH_F #undef HWY_SVE_FOREACH_F16 #undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F3264 #undef HWY_SVE_FOREACH_F64 #undef HWY_SVE_FOREACH_I #undef HWY_SVE_FOREACH_I08 @@ -5086,7 +6287,10 @@ namespace detail { // for code folding #undef HWY_SVE_FOREACH_UI64 #undef HWY_SVE_FOREACH_UIF3264 #undef HWY_SVE_HAVE_2 +#undef HWY_SVE_IF_EMULATED_D +#undef HWY_SVE_IF_NOT_EMULATED_D #undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGMVV #undef HWY_SVE_RETV_ARGPV #undef HWY_SVE_RETV_ARGPVN #undef HWY_SVE_RETV_ARGPVV @@ -5098,7 +6302,6 @@ namespace detail { // for code folding #undef HWY_SVE_UNDEFINED #undef HWY_SVE_V -} // namespace detail // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/r/src/vendor/highway/hwy/ops/emu128-inl.h b/r/src/vendor/highway/hwy/ops/emu128-inl.h index 74473652..5c5ed987 100644 --- a/r/src/vendor/highway/hwy/ops/emu128-inl.h +++ b/r/src/vendor/highway/hwy/ops/emu128-inl.h @@ -16,7 +16,11 @@ // Single-element vectors and operations. // External include guard in highway.h - see comment there. 
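One theme running through the emu128 hunks below: per-lane code that used to call into the standard library (std::abs, std::isnan, std::sqrt, CopySameSize round-trips for bit casts) now goes through scalar helpers pulled in via the new include of hwy/base.h (ScalarAbs, ScalarIsNaN, ScalarSignBit, ScalarShr, BitCastScalar, ConvertScalarTo), and the remaining math include is guarded by HWY_NO_LIBCXX, presumably so the emulated target can still build without a C++ standard library. For example:

    // before: v.raw[i] = std::abs(v.raw[i]);
    // after:  v.raw[i] = ScalarAbs(v.raw[i]);  // no std:: dependency
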
-#include // std::abs, std::isnan +#include "hwy/base.h" + +#ifndef HWY_NO_LIBCXX +#include // sqrtf +#endif #include "hwy/ops/shared-inl.h" @@ -49,6 +53,9 @@ struct Vec128 { HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } @@ -97,15 +104,12 @@ HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { template using VFromD = decltype(Zero(D())); -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ BitCast template HWY_API VFromD BitCast(D /* tag */, VFrom v) { VFromD to; - CopySameSize(&v, &to); + CopySameSize(&v.raw, &to.raw); return to; } @@ -122,7 +126,7 @@ HWY_API VFromD ResizeBitCast(D d, VFrom v) { constexpr size_t kCopyByteLen = HWY_MIN(kFromByteLen, kToByteLen); VFromD to = Zero(d); - CopyBytes(&v, &to); + CopyBytes(&v.raw, &to.raw); return to; } @@ -145,7 +149,7 @@ template HWY_API VFromD Set(D d, const T2 t) { VFromD v; for (size_t i = 0; i < MaxLanes(d); ++i) { - v.raw[i] = static_cast>(t); + v.raw[i] = ConvertScalarTo>(t); } return v; } @@ -156,14 +160,79 @@ HWY_API VFromD Undefined(D d) { return Zero(d); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + result.raw[8] = t8; + result.raw[9] = t9; + result.raw[10] = t10; + result.raw[11] = t11; + result.raw[12] = t12; + result.raw[13] = t13; + result.raw[14] = t14; + result.raw[15] = t15; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + return result; +} + // ------------------------------ Iota template , typename T2> HWY_API VFromD Iota(D d, T2 first) { VFromD v; for (size_t i = 0; i < MaxLanes(d); ++i) { - v.raw[i] = - AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + v.raw[i] = AddWithWraparound(static_cast(first), i); } return v; } @@ -284,9 +353,8 @@ HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { // ------------------------------ BroadcastSignBit template HWY_API Vec128 BroadcastSignBit(Vec128 v) { - // This is used inside ShiftRight, so we cannot implement in terms of it. for (size_t i = 0; i < N; ++i) { - v.raw[i] = v.raw[i] < 0 ? 
T(-1) : T(0); + v.raw[i] = ScalarShr(v.raw[i], sizeof(T) * 8 - 1); } return v; } @@ -297,7 +365,7 @@ HWY_API Vec128 BroadcastSignBit(Vec128 v) { template HWY_API Mask128 MaskFromVec(Vec128 v) { Mask128 mask; - CopySameSize(&v, &mask); + CopySameSize(&v.raw, &mask.bits); return mask; } @@ -307,20 +375,15 @@ using MFromD = decltype(MaskFromVec(VFromD())); template HWY_API MFromD RebindMask(DTo /* tag */, MFrom mask) { MFromD to; - CopySameSize(&mask, &to); + CopySameSize(&mask.bits, &to.bits); return to; } -template -Vec128 VecFromMask(Mask128 mask) { - Vec128 v; - CopySameSize(&mask, &v); - return v; -} - template VFromD VecFromMask(D /* tag */, MFromD mask) { - return VecFromMask(mask); + VFromD v; + CopySameSize(&mask.bits, &v.raw); + return v; } template @@ -336,19 +399,20 @@ HWY_API MFromD FirstN(D d, size_t n) { template HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, Vec128 no) { - return IfVecThenElse(VecFromMask(mask), yes, no); + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), yes, no); } template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { const DFromV d; - return IfVecThenElse(VecFromMask(mask), yes, Zero(d)); + return IfVecThenElse(VecFromMask(d, mask), yes, Zero(d)); } template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { const DFromV d; - return IfVecThenElse(VecFromMask(mask), Zero(d), no); + return IfVecThenElse(VecFromMask(d, mask), Zero(d), no); } template @@ -364,17 +428,12 @@ HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, return v; } -template -HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const DFromV d; - return IfNegativeThenElse(v, Zero(d), v); -} - // ------------------------------ Mask logical template HWY_API Mask128 Not(Mask128 m) { - return MaskFromVec(Not(VecFromMask(Simd(), m))); + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); } template @@ -426,41 +485,26 @@ HWY_API Vec128 ShiftLeft(Vec128 v) { template HWY_API Vec128 ShiftRight(Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); -#if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> kBits); - } -#else - if (IsSigned()) { - // Emulate arithmetic shift using only logical (unsigned) shifts, because - // signed shifts are still implementation-defined. - using TU = hwy::MakeUnsigned; - for (size_t i = 0; i < N; ++i) { - const TU shifted = static_cast(static_cast(v.raw[i]) >> kBits); - const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; - const size_t sign_shift = - static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); - const TU upper = static_cast(sign << sign_shift); - v.raw[i] = static_cast(shifted | upper); - } - } else { // T is unsigned - for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> kBits); - } + v.raw[i] = ScalarShr(v.raw[i], kBits); } -#endif + return v; } // ------------------------------ RotateRight (ShiftRight) -template +template HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; - return Or(ShiftRight(v), + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } @@ -477,31 +521,10 @@ HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { template HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { -#if __cplusplus >= 202002L - // Signed right shift is now guaranteed to be arithmetic (rounding toward - // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> bits); - } -#else - if (IsSigned()) { - // Emulate arithmetic shift using only logical (unsigned) shifts, because - // signed shifts are still implementation-defined. - using TU = hwy::MakeUnsigned; - for (size_t i = 0; i < N; ++i) { - const TU shifted = static_cast(static_cast(v.raw[i]) >> bits); - const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; - const size_t sign_shift = - static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); - const TU upper = static_cast(sign << sign_shift); - v.raw[i] = static_cast(shifted | upper); - } - } else { - for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> bits); // unsigned, logical shift - } + v.raw[i] = ScalarShr(v.raw[i], bits); } -#endif + return v; } @@ -519,32 +542,10 @@ HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { template HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { -#if __cplusplus >= 202002L - // Signed right shift is now guaranteed to be arithmetic (rounding toward - // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); - } -#else - if (IsSigned()) { - // Emulate arithmetic shift using only logical (unsigned) shifts, because - // signed shifts are still implementation-defined. - using TU = hwy::MakeUnsigned; - for (size_t i = 0; i < N; ++i) { - const TU shifted = - static_cast(static_cast(v.raw[i]) >> bits.raw[i]); - const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; - const size_t sign_shift = static_cast( - static_cast(sizeof(TU)) * 8 - 1 - bits.raw[i]); - const TU upper = static_cast(sign << sign_shift); - v.raw[i] = static_cast(shifted | upper); - } - } else { // T is unsigned - for (size_t i = 0; i < N; ++i) { - v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); - } + v.raw[i] = ScalarShr(v.raw[i], static_cast(bits.raw[i])); } -#endif + return v; } @@ -614,6 +615,15 @@ HWY_API Vec128 SumsOf8(Vec128 v) { return sums; } +template +HWY_API Vec128 SumsOf8(Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + // ------------------------------ SaturatedAdd template @@ -641,45 +651,40 @@ HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { } // ------------------------------ AverageRound -template + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { - static_assert(!IsSigned(), "Only for unsigned"); for (size_t i = 0; i < N; ++i) { - a.raw[i] = static_cast((a.raw[i] + b.raw[i] + 1) / 2); + const T a_val = a.raw[i]; + const T b_val = b.raw[i]; + a.raw[i] = static_cast(ScalarShr(a_val, 1) + ScalarShr(b_val, 1) + + ((a_val | b_val) & 1)); } return a; } // ------------------------------ Abs -// Tag dispatch instead of SFINAE for MSVC 2017 compatibility -namespace detail { - template -HWY_INLINE Vec128 Abs(SignedTag /*tag*/, Vec128 a) { +HWY_API Vec128 Abs(Vec128 a) { for (size_t i = 0; i < N; ++i) { - const T s = a.raw[i]; - const T min = hwy::LimitsMin(); - a.raw[i] = static_cast((s >= 0 || s == min) ? a.raw[i] : -s); + a.raw[i] = ScalarAbs(a.raw[i]); } return a; } -template -HWY_INLINE Vec128 Abs(hwy::FloatTag /*tag*/, Vec128 v) { - for (size_t i = 0; i < N; ++i) { - v.raw[i] = std::abs(v.raw[i]); - } - return v; -} - -} // namespace detail - -template -HWY_API Vec128 Abs(Vec128 a) { - return detail::Abs(hwy::TypeTag(), a); -} - // ------------------------------ Min/Max // Tag dispatch instead of SFINAE for MSVC 2017 compatibility @@ -706,9 +711,9 @@ template HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { - if (std::isnan(a.raw[i])) { + if (ScalarIsNaN(a.raw[i])) { a.raw[i] = b.raw[i]; - } else if (std::isnan(b.raw[i])) { + } else if (ScalarIsNaN(b.raw[i])) { // no change } else { a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); @@ -720,9 +725,9 @@ template HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { - if (std::isnan(a.raw[i])) { + if (ScalarIsNaN(a.raw[i])) { a.raw[i] = b.raw[i]; - } else if (std::isnan(b.raw[i])) { + } else if (ScalarIsNaN(b.raw[i])) { // no change } else { a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); @@ -825,7 +830,7 @@ HWY_API Vec128 operator*(Vec128 a, Vec128 b) { return detail::Mul(hwy::TypeTag(), a, b); } -template +template HWY_API Vec128 operator/(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = (b.raw[i] == T{0}) ? 0 : a.raw[i] / b.raw[i]; @@ -833,26 +838,36 @@ HWY_API Vec128 operator/(Vec128 a, Vec128 b) { return a; } -// Returns the upper 16 bits of a * b in each lane. -template -HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { +// Returns the upper sizeof(T)*8 bits of a * b in each lane. 
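The comment above generalizes MulHigh from 16-bit lanes to every integer lane size. For lanes narrower than 64 bits the high half falls out of a multiply in the next wider type; the 64-bit overloads just below have no wider type to lean on and instead call Mul128(GetLane(a), GetLane(b), &hi). A minimal scalar sketch of the narrow-lane identity (assumes hwy/base.h for MakeWide and an unsigned T narrower than 64 bits):

    #include "hwy/base.h"
    template <typename T>
    T MulHighScalar(T a, T b) {
      using TW = hwy::MakeWide<T>;  // e.g. uint32_t when T is uint16_t
      return static_cast<T>(
          (static_cast<TW>(a) * static_cast<TW>(b)) >> (sizeof(T) * 8));
    }
    // MulHighScalar<uint16_t>(0xFFFF, 0xFFFF) == 0xFFFE,
    // since 0xFFFF * 0xFFFF = 0xFFFE'0001.
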
+template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + using TW = MakeWide; for (size_t i = 0; i < N; ++i) { - a.raw[i] = static_cast((int32_t{a.raw[i]} * b.raw[i]) >> 16); + a.raw[i] = static_cast( + (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> + (sizeof(T) * 8)); } return a; } -template -HWY_API Vec128 MulHigh(Vec128 a, - Vec128 b) { - for (size_t i = 0; i < N; ++i) { - // Cast to uint32_t first to prevent overflow. Otherwise the result of - // uint16_t * uint16_t is in "int" which may overflow. In practice the - // result is the same but this way it is also defined. - a.raw[i] = static_cast( - (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> - 16); - } - return a; + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(ExtractLane(a, 1), ExtractLane(b, 1), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); } template @@ -900,7 +915,7 @@ HWY_API Vec128 ApproximateReciprocal(Vec128 v) { // Zero inputs are allowed, but callers are responsible for replacing the // return value with something else (typically using IfThenElse). This check // avoids a ubsan error. The result is arbitrary. - v.raw[i] = (std::abs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + v.raw[i] = (ScalarAbs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; } return v; } @@ -913,25 +928,25 @@ HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { // ------------------------------ Floating-point multiply-add variants -template +template HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, Vec128 add) { return mul * x + add; } -template +template HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, Vec128 add) { return add - mul * x; } -template +template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { return mul * x - sub; } -template +template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { return Neg(mul) * x - sub; @@ -943,21 +958,52 @@ template HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { const float half = v.raw[i] * 0.5f; - uint32_t bits; - CopySameSize(&v.raw[i], &bits); // Initial guess based on log2(f) - bits = 0x5F3759DF - (bits >> 1); - CopySameSize(&bits, &v.raw[i]); + v.raw[i] = BitCastScalar(static_cast( + 0x5F3759DF - (BitCastScalar(v.raw[i]) >> 1))); // One Newton-Raphson iteration v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); } return v; } +namespace detail { + +static HWY_INLINE float ScalarSqrt(float v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint32_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1 << 29) + (bits >> 1) - (1 << 22); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrtf(v); +#endif // !HWY_NO_LIBCXX +} +static HWY_INLINE double ScalarSqrt(double v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint64_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrt(v); +#endif // HWY_NO_LIBCXX +} + +} // namespace detail + template HWY_API Vec128 Sqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { - v.raw[i] = std::sqrt(v.raw[i]); + v.raw[i] = 
detail::ScalarSqrt(v.raw[i]); } return v; } @@ -967,21 +1013,23 @@ HWY_API Vec128 Sqrt(Vec128 v) { template HWY_API Vec128 Round(Vec128 v) { using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); const Vec128 a = Abs(v); for (size_t i = 0; i < N; ++i) { if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN continue; } - const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); - const TI rounded = static_cast(v.raw[i] + bias); + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); if (rounded == 0) { - v.raw[i] = v.raw[i] < 0 ? T{-0} : T{0}; + v.raw[i] = v.raw[i] < 0 ? ConvertScalarTo(-0) : k0; continue; } - const T rounded_f = static_cast(rounded); + const T rounded_f = ConvertScalarTo(rounded); // Round to even - if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { - v.raw[i] = static_cast(rounded - (v.raw[i] < T(0) ? -1 : 1)); + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + v.raw[i] = ConvertScalarTo(rounded - (v.raw[i] < k0 ? -1 : 1)); continue; } v.raw[i] = rounded_f; @@ -990,34 +1038,73 @@ HWY_API Vec128 Round(Vec128 v) { } // Round-to-nearest even. -template -HWY_API Vec128 NearestInt(Vec128 v) { - using T = float; - using TI = int32_t; +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); - const Vec128 abs = Abs(v); - Vec128 ret; + const Vec128 abs = Abs(v); + Vec128 ret; for (size_t i = 0; i < N; ++i) { - const bool signbit = std::signbit(v.raw[i]); + const bool signbit = ScalarSignBit(v.raw[i]); if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN // Check if too large to cast or NaN - if (!(abs.raw[i] <= static_cast(LimitsMax()))) { + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); continue; } ret.raw[i] = static_cast(v.raw[i]); continue; } - const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); - const TI rounded = static_cast(v.raw[i] + bias); + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, + VFromD> v) { + using T = double; + using TI = int32_t; + const T k0 = ConvertScalarTo(0); + + constexpr size_t N = HWY_MAX_LANES_D(DI32); + + const VFromD> abs = Abs(v); + VFromD ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = ScalarSignBit(v.raw[i]); + + // Check if too large to cast or NaN + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); if (rounded == 0) { ret.raw[i] = 0; continue; } - const T rounded_f = static_cast(rounded); + const T rounded_f = ConvertScalarTo(rounded); // Round to even - if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { ret.raw[i] = rounded - (signbit ? 
-1 : 1); continue; } @@ -1056,8 +1143,7 @@ Vec128 Ceil(Vec128 v) { for (size_t i = 0; i < N; ++i) { const bool positive = v.raw[i] > Float(0.0); - Bits bits; - CopySameSize(&v.raw[i], &bits); + Bits bits = BitCastScalar(v.raw[i]); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); @@ -1077,7 +1163,7 @@ Vec128 Ceil(Vec128 v) { if (positive) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; - CopySameSize(&bits, &v.raw[i]); + v.raw[i] = BitCastScalar(bits); } return v; } @@ -1094,8 +1180,7 @@ Vec128 Floor(Vec128 v) { for (size_t i = 0; i < N; ++i) { const bool negative = v.raw[i] < Float(0.0); - Bits bits; - CopySameSize(&v.raw[i], &bits); + Bits bits = BitCastScalar(v.raw[i]); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); @@ -1115,7 +1200,7 @@ Vec128 Floor(Vec128 v) { if (negative) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; - CopySameSize(&bits, &v.raw[i]); + v.raw[i] = BitCastScalar(bits); } return v; } @@ -1127,44 +1212,11 @@ HWY_API Mask128 IsNaN(Vec128 v) { Mask128 ret; for (size_t i = 0; i < N; ++i) { // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. - MakeUnsigned bits; - CopySameSize(&v.raw[i], &bits); - bits += bits; - bits >>= 1; // clear sign bit - // NaN if all exponent bits are set and the mantissa is not zero. - ret.bits[i] = Mask128::FromBool(bits > ExponentMask()); + ret.bits[i] = Mask128::FromBool(ScalarIsNaN(v.raw[i])); } return ret; } -template -HWY_API Mask128 IsInf(Vec128 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); - // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); -} - -// Returns whether normal/subnormal/zero. -template -HWY_API Mask128 IsFinite(Vec128 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToUnsigned du; - const RebindToSigned di; // cheaper than unsigned comparison - using VI = VFromD; - using VU = VFromD; - const VU vu = BitCast(du, v); - // 'Shift left' to clear the sign bit, then right so we can compare with the - // max exponent (cannot compare with MaxExponentTimes2 directly because it is - // negative and non-negative floats would be greater). - const VI exp = - BitCast(di, ShiftRight() + 1>(Add(vu, vu))); - return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); -} - // ================================================== COMPARE template @@ -1400,93 +1452,277 @@ HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, CopyBytes(v.raw, p, num_of_lanes_to_store * sizeof(TFromD)); } -// ------------------------------ LoadInterleaved2/3/4 - -// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. -// We implement those here because scalar code is likely faster than emulation -// via shuffles. 
-#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED -#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED -#else -#define HWY_NATIVE_LOAD_STORE_INTERLEAVED -#endif +// ================================================== COMBINE -template > -HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1) { - alignas(16) T buf0[MaxLanes(d)]; - alignas(16) T buf1[MaxLanes(d)]; - for (size_t i = 0; i < MaxLanes(d); ++i) { - buf0[i] = *unaligned++; - buf1[i] = *unaligned++; - } - v0 = Load(d, buf0); - v1 = Load(d, buf1); +template +HWY_API Vec128 LowerHalf(Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; } -template > -HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - alignas(16) T buf0[MaxLanes(d)]; - alignas(16) T buf1[MaxLanes(d)]; - alignas(16) T buf2[MaxLanes(d)]; - for (size_t i = 0; i < MaxLanes(d); ++i) { - buf0[i] = *unaligned++; - buf1[i] = *unaligned++; - buf2[i] = *unaligned++; - } - v0 = Load(d, buf0); - v1 = Load(d, buf1); - v2 = Load(d, buf2); +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return LowerHalf(v); } -template > -HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - alignas(16) T buf0[MaxLanes(d)]; - alignas(16) T buf1[MaxLanes(d)]; - alignas(16) T buf2[MaxLanes(d)]; - alignas(16) T buf3[MaxLanes(d)]; - for (size_t i = 0; i < MaxLanes(d); ++i) { - buf0[i] = *unaligned++; - buf1[i] = *unaligned++; - buf2[i] = *unaligned++; - buf3[i] = *unaligned++; - } - v0 = Load(d, buf0); - v1 = Load(d, buf1); - v2 = Load(d, buf2); - v3 = Load(d, buf3); +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + VFromD ret; + CopyBytes(&v.raw[MaxLanes(d)], ret.raw); + return ret; } -// ------------------------------ StoreInterleaved2/3/4 - template -HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, - TFromD* HWY_RESTRICT unaligned) { - for (size_t i = 0; i < MaxLanes(d); ++i) { - *unaligned++ = v0.raw[i]; - *unaligned++ = v1.raw[i]; - } +HWY_API VFromD ZeroExtendVector(D d, VFromD> v) { + const Half dh; + VFromD ret; // zero-initialized + CopyBytes(v.raw, ret.raw); + return ret; } -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - for (size_t i = 0; i < MaxLanes(d); ++i) { - *unaligned++ = v0.raw[i]; - *unaligned++ = v1.raw[i]; - *unaligned++ = v2.raw[i]; - } +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + VFromD ret; + CopyBytes(lo_half.raw, &ret.raw[0]); + CopyBytes(hi_half.raw, &ret.raw[MaxLanes(dh)]); + return ret; } template -HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, - VFromD v3, D d, - TFromD* HWY_RESTRICT unaligned) { - for (size_t i = 0; i < MaxLanes(d); ++i) { - *unaligned++ = v0.raw[i]; +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, 
VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i]; + } + return ret; +} + +// 2023-11-23: workaround for incorrect codegen (reduction_test fails for +// SumsOf2 because PromoteOddTo, which uses ConcatOdd, returns zero). +#if HWY_ARCH_RISCV && HWY_TARGET == HWY_EMU128 && HWY_COMPILER_CLANG +#define HWY_EMU128_CONCAT_INLINE HWY_NOINLINE +#else +#define HWY_EMU128_CONCAT_INLINE HWY_API +#endif + +template +HWY_EMU128_CONCAT_INLINE VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + VFromD ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(lo8 + kBytes, ret8); + CopyBytes(hi.raw, ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + ZeroBytes(ret8); + CopyBytes(v.raw, ret8 + kBytes); + return ret; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(v8 + kBytes, ret8); + ZeroBytes(ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ Tuples, PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// Same for Load/StoreInterleaved of special floats. 
+#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + alignas(16) T buf3[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; *unaligned++ = v2.raw[i]; *unaligned++ = v3.raw[i]; @@ -1510,67 +1746,100 @@ HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { namespace detail { template -HWY_INLINE ToT CastValueForF2IConv(hwy::UnsignedTag /* to_type_tag */, - FromT val) { - // Prevent ubsan errors when converting float to narrower integer - - // If LimitsMax() can be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax(). - - // Otherwise, if LimitsMax() cannot be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax() + 1, which can - // be exactly represented in FromT. - constexpr FromT kSmallestOutOfToTRangePosVal = - (sizeof(ToT) * 8 <= static_cast(MantissaBits()) + 1) - ? 
static_cast(LimitsMax()) - : static_cast( - static_cast(ToT{1} << (sizeof(ToT) * 8 - 1)) * FromT(2)); - - if (std::signbit(val)) { - return ToT{0}; - } else if (std::isinf(val) || val >= kSmallestOutOfToTRangePosVal) { - return LimitsMax(); - } else { - return static_cast(val); - } -} - -template -HWY_INLINE ToT CastValueForF2IConv(hwy::SignedTag /* to_type_tag */, - FromT val) { +HWY_INLINE ToT CastValueForF2IConv(FromT val) { // Prevent ubsan errors when converting float to narrower integer - // If LimitsMax() can be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax(). - - // Otherwise, if LimitsMax() cannot be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to -LimitsMin(), which can - // be exactly represented in FromT. - constexpr FromT kSmallestOutOfToTRangePosVal = - (sizeof(ToT) * 8 <= static_cast(MantissaBits()) + 2) - ? static_cast(LimitsMax()) - : static_cast(-static_cast(LimitsMin())); - - if (std::isinf(val) || std::fabs(val) >= kSmallestOutOfToTRangePosVal) { - return std::signbit(val) ? LimitsMin() : LimitsMax(); - } else { - return static_cast(val); - } + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); } template HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { - return static_cast(val); + return ConvertScalarTo(val); } template -HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag to_type_tag, float val) { - return CastValueForF2IConv(to_type_tag, val); +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); } template -HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag to_type_tag, float val) { - return CastValueForF2IConv(to_type_tag, val); +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. 
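// A standalone sketch of the exponent-field range check used by
// CastValueForF2IConv above and CastValueForInRangeF2IConv below, specialized
// to the single case float -> int32_t (plain C++; the helper name and the use
// of memcpy for the bit cast are choices made for this example). Here
// kMinOutOfRangeExpField is 127 + 31: any |val| whose biased exponent is at
// least 158 is >= 2^31 (or NaN/inf) and saturates; the in-range variant below
// would instead return LimitsMin<int32_t>() for such inputs.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

int32_t SaturatingF32ToI32(float val) {
  uint32_t bits;
  std::memcpy(&bits, &val, sizeof(bits));                  // bit cast without UB
  const uint32_t exp_field = (bits & 0x7FFFFFFFu) >> 23;   // exponent field of |val|
  constexpr uint32_t kMinOutOfRangeExpField = 127u + 31u;  // bias + 31 for int32_t
  if (exp_field < kMinOutOfRangeExpField) {
    return static_cast<int32_t>(val);  // in range: the conversion is well-defined
  }
  // Out of range (including NaN/inf): clamp according to the sign bit. Note
  // that exactly -2^31 also lands here and clamps to the correct value.
  const bool negative = (bits >> 31) != 0;
  return negative ? std::numeric_limits<int32_t>::min()
                  : std::numeric_limits<int32_t>::max();
}

int main() {
  std::printf("%d %d %d\n", SaturatingF32ToI32(1.5f), SaturatingF32ToI32(3e9f),
              SaturatingF32ToI32(-3e9f));  // 1 2147483647 -2147483648
  return 0;
}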
+template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(LimitsMin()); } } // namespace detail @@ -1587,6 +1856,21 @@ HWY_API VFromD PromoteTo(DTo d, Vec128 from) { return ret; } +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d64); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + // MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, // so we overload for TFrom=double and ToT={float,int32_t}. template @@ -1594,10 +1878,10 @@ HWY_API VFromD DemoteTo(D d, VFromD> from) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { // Prevent ubsan errors when converting float to narrower integer/float - if (std::isinf(from.raw[i]) || - std::fabs(from.raw[i]) > static_cast(HighestValue())) { - ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() - : HighestValue(); + if (ScalarIsInf(from.raw[i]) || + ScalarAbs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = ScalarSignBit(from.raw[i]) ? 
LowestValue() + : HighestValue(); continue; } ret.raw[i] = static_cast(from.raw[i]); @@ -1609,8 +1893,7 @@ HWY_API VFromD DemoteTo(D d, VFromD> from) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { // Prevent ubsan errors when converting double to narrower integer/int32_t - ret.raw[i] = detail::CastValueForF2IConv>( - hwy::TypeTag>(), from.raw[i]); + ret.raw[i] = detail::CastValueForF2IConv>(from.raw[i]); } return ret; } @@ -1631,17 +1914,32 @@ HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { return ret; } +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on EMU128 as the EMU128 target has +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + template + HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DTo)> HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { using TTo = TFromD; static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + const auto max = static_cast>(LimitsMax()); + VFromD ret; for (size_t i = 0; i < N; ++i) { // Int to int: choose closest value in ToT to `from` (avoids UB) - from.raw[i] = HWY_MIN(from.raw[i], LimitsMax()); - ret.raw[i] = static_cast(from.raw[i]); + ret.raw[i] = static_cast(HWY_MIN(from.raw[i], max)); } return ret; } @@ -1689,14 +1987,15 @@ HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { return ret; } -template ) * 2), +template ) * 2), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { const RepartitionToWide dw; const size_t NW = Lanes(dw); using TN = TFromD; - const TN max = LimitsMax(); + using TN_U = MakeUnsigned; + const TN_U max = static_cast(LimitsMax()); VFromD ret; for (size_t i = 0; i < NW; ++i) { ret.raw[i] = static_cast(HWY_MIN(a.raw[i], max)); @@ -1715,23 +2014,20 @@ HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { return ReorderDemote2To(dn, a, b); } -template ), +template ), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { - const RebindToUnsigned> du32; - const size_t NW = Lanes(du32); - VFromD> ret; - - const auto a_bits = BitCast(du32, a); - const auto b_bits = BitCast(du32, b); - + const size_t NW = Lanes(dn) / 2; + using TN = TFromD; + VFromD ret; for (size_t i = 0; i < NW; ++i) { - ret.raw[i] = static_cast(a_bits.raw[i] >> 16); + ret.raw[i] = ConvertScalarTo(a.raw[i]); } for (size_t i = 0; i < NW; ++i) { - ret.raw[NW + i] = static_cast(b_bits.raw[i] >> 16); + ret.raw[NW + i] = ConvertScalarTo(b.raw[i]); } - return BitCast(dn, ret); + return ret; } namespace detail { @@ -1758,6 +2054,12 @@ HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { return ret; } +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + template HWY_API VFromD DemoteTo(D /* tag */, Vec128 v) { VFromD ret; @@ -1767,6 +2069,21 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec128 v) { return ret; } +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD 
DemoteInRangeTo(D32 d32, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d32); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { @@ -1780,7 +2097,7 @@ HWY_API VFromD ConvertTo(hwy::FloatTag /*tag*/, DTo /*tag*/, for (size_t i = 0; i < N; ++i) { // float## -> int##: return closest representable value - ret.raw[i] = CastValueForF2IConv(hwy::TypeTag(), from.raw[i]); + ret.raw[i] = CastValueForF2IConv(from.raw[i]); } return ret; } @@ -1806,6 +2123,22 @@ HWY_API VFromD ConvertTo(DTo d, Vec128 from) { return detail::ConvertTo(hwy::IsFloatTag(), d, from); } +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(di); i++) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + template HWY_API Vec128 U8FromU32(Vec128 v) { return DemoteTo(Simd(), v); @@ -1893,172 +2226,6 @@ HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { return ret; } -// ================================================== COMBINE - -template -HWY_API Vec128 LowerHalf(Vec128 v) { - Vec128 ret; - CopyBytes(v.raw, ret.raw); - return ret; -} - -template -HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { - return LowerHalf(v); -} - -template -HWY_API VFromD UpperHalf(D d, VFromD> v) { - VFromD ret; - CopyBytes(&v.raw[MaxLanes(d)], ret.raw); - return ret; -} - -template -HWY_API VFromD ZeroExtendVector(D d, VFromD> v) { - const Half dh; - VFromD ret; // zero-initialized - CopyBytes(v.raw, ret.raw); - return ret; -} - -template >> -HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { - const Half dh; - VFromD ret; - CopyBytes(lo_half.raw, &ret.raw[0]); - CopyBytes(hi_half.raw, &ret.raw[MaxLanes(dh)]); - return ret; -} - -template -HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - CopyBytes(lo.raw, &ret.raw[0]); - CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); - return ret; -} - -template -HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); - CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); - return ret; -} - -template -HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); - CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); - return ret; -} - -template -HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - CopyBytes(lo.raw, &ret.raw[0]); - CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); - return ret; -} - -template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - for (size_t i = 0; i < MaxLanes(dh); ++i) { - ret.raw[i] = lo.raw[2 * i]; - } - for (size_t i = 0; i < MaxLanes(dh); ++i) { - ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i]; - } - return ret; -} - -template -HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { - const Half dh; - VFromD ret; - for (size_t i = 0; i < MaxLanes(dh); ++i) { - ret.raw[i] = lo.raw[2 * i + 1]; - } - for (size_t i = 0; i < MaxLanes(dh); ++i) { - ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i + 1]; - } - return ret; -} - -// ------------------------------ CombineShiftRightBytes -template -HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, 
VFromD lo) { - VFromD ret; - const uint8_t* HWY_RESTRICT lo8 = - reinterpret_cast(lo.raw); - uint8_t* HWY_RESTRICT ret8 = - reinterpret_cast(ret.raw); - CopyBytes(lo8 + kBytes, ret8); - CopyBytes(hi.raw, ret8 + d.MaxBytes() - kBytes); - return ret; -} - -// ------------------------------ ShiftLeftBytes - -template -HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { - static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); - VFromD ret; - uint8_t* HWY_RESTRICT ret8 = - reinterpret_cast(ret.raw); - ZeroBytes(ret8); - CopyBytes(v.raw, ret8 + kBytes); - return ret; -} - -template -HWY_API Vec128 ShiftLeftBytes(Vec128 v) { - return ShiftLeftBytes(DFromV(), v); -} - -// ------------------------------ ShiftLeftLanes - -template > -HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { - const Repartition d8; - return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); -} - -template -HWY_API Vec128 ShiftLeftLanes(Vec128 v) { - return ShiftLeftLanes(DFromV(), v); -} - -// ------------------------------ ShiftRightBytes -template -HWY_API VFromD ShiftRightBytes(D d, VFromD v) { - static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); - VFromD ret; - const uint8_t* HWY_RESTRICT v8 = - reinterpret_cast(v.raw); - uint8_t* HWY_RESTRICT ret8 = - reinterpret_cast(ret.raw); - CopyBytes(v8 + kBytes, ret8); - ZeroBytes(ret8 + d.MaxBytes() - kBytes); - return ret; -} - -// ------------------------------ ShiftRightLanes -template -HWY_API VFromD ShiftRightLanes(D d, VFromD v) { - const Repartition d8; - constexpr size_t kBytes = kLanes * sizeof(TFromD); - return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); -} - // ================================================== SWIZZLE template @@ -2101,6 +2268,24 @@ HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { return odd; } +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + a.raw[i] = b.raw[i - 1]; + } + return a; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + b.raw[i - 1] = a.raw[i]; + } + return b; +} + template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { return even; @@ -2349,8 +2534,8 @@ HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { } // Additional overload for the optional tag. -template -HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { return InterleaveLower(a, b); } @@ -2416,6 +2601,15 @@ HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { return m; } +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + m.bits[i] = MFromD::FromBool(((mask_bits >> i) & 1u) != 0); + } + return m; +} + // `p` points to at least 8 writable bytes. 
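// A standalone sketch of how Dup128MaskFromMaskBits above expands packed mask
// bits into per-lane booleans (plain C++; the lane count 8 is an arbitrary
// choice for the example).
#include <array>
#include <cstddef>
#include <cstdio>

std::array<bool, 8> MaskFromBitsSketch(unsigned mask_bits) {
  std::array<bool, 8> lanes{};
  for (std::size_t i = 0; i < lanes.size(); ++i) {
    lanes[i] = ((mask_bits >> i) & 1u) != 0;  // lane i is set iff bit i is set
  }
  return lanes;
}

int main() {
  const std::array<bool, 8> m = MaskFromBitsSketch(0xB5u);  // bits 0,2,4,5,7 set
  for (bool b : m) std::printf("%d", b ? 1 : 0);            // prints 10101101
  std::printf("\n");
  return 0;
}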
template HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { @@ -2517,7 +2711,7 @@ HWY_API Vec128 Expand(Vec128 v, const Mask128 mask) { if (mask.bits[i]) { ret.raw[i] = v.raw[in_pos++]; } else { - ret.raw[i] = T(); // zero, also works for float16_t + ret.raw[i] = ConvertScalarTo(0); } } return ret; @@ -2662,88 +2856,26 @@ HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { // ------------------------------ WidenMulPairwiseAdd -template -HWY_API VFromD WidenMulPairwiseAdd(D df32, VBF16 a, VBF16 b) { - const Rebind du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Avoid ZipLower/Upper so this also works on big-endian systems. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return Mul(BitCast(df32, ae), BitCast(df32, be)) + - Mul(BitCast(df32, ao), BitCast(df32, bo)); -} - -template -HWY_API VFromD WidenMulPairwiseAdd(D d32, VI16 a, VI16 b) { - using VI32 = VFromD; - // Manual sign extension requires two shifts for even lanes. - const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); - const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); - const VI32 ao = ShiftRight<16>(BitCast(d32, a)); - const VI32 bo = ShiftRight<16>(BitCast(d32, b)); - return Add(Mul(ae, be), Mul(ao, bo)); +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); } -template -HWY_API VFromD WidenMulPairwiseAdd(D du32, VU16 a, VU16 b) { - const auto lo16_mask = Set(du32, 0x0000FFFFu); - - const auto a0 = And(BitCast(du32, a), lo16_mask); - const auto b0 = And(BitCast(du32, b), lo16_mask); - - const auto a1 = ShiftRight<16>(BitCast(du32, a)); - const auto b1 = ShiftRight<16>(BitCast(du32, b)); - - return Add(Mul(a0, b0), Mul(a1, b1)); +template +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template -HWY_API VFromD ReorderWidenMulAccumulate(D df32, VBF16 a, VBF16 b, - const Vec128 sum0, - Vec128& sum1) { - const Rebind du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Avoid ZipLower/Upper so this also works on big-endian systems. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); -} - -template -HWY_API VFromD ReorderWidenMulAccumulate(D d32, VI16 a, VI16 b, - const Vec128 sum0, - Vec128& sum1) { - using VI32 = VFromD; - // Manual sign extension requires two shifts for even lanes. 
- const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); - const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); - const VI32 ao = ShiftRight<16>(BitCast(d32, a)); - const VI32 bo = ShiftRight<16>(BitCast(d32, b)); - sum1 = Add(Mul(ao, bo), sum1); - return Add(Mul(ae, be), sum0); -} - -template -HWY_API VFromD ReorderWidenMulAccumulate(D du32, VU16 a, VU16 b, - const Vec128 sum0, - Vec128& sum1) { - using VU32 = VFromD; - const VU32 lo16_mask = Set(du32, uint32_t{0x0000FFFFu}); - const VU32 ae = And(BitCast(du32, a), lo16_mask); - const VU32 be = And(BitCast(du32, b), lo16_mask); - const VU32 ao = ShiftRight<16>(BitCast(du32, a)); - const VU32 bo = ShiftRight<16>(BitCast(du32, b)); - sum1 = Add(Mul(ao, bo), sum1); - return Add(Mul(ae, be), sum0); +template +HWY_API VFromD ReorderWidenMulAccumulate(D d32, V16 a, V16 b, + const VFromD sum0, + VFromD& sum1) { + sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); } // ------------------------------ RearrangeToOddPlusEven @@ -2754,15 +2886,13 @@ HWY_API VW RearrangeToOddPlusEven(VW sum0, VW sum1) { // ================================================== REDUCTIONS -template > -HWY_API VFromD SumOfLanes(D d, VFromD v) { - T sum = T{0}; - for (size_t i = 0; i < MaxLanes(d); ++i) { - sum += v.raw[i]; - } - return Set(d, sum); -} -template > +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +template , HWY_IF_REDUCE_D(D)> HWY_API T ReduceSum(D d, VFromD v) { T sum = T{0}; for (size_t i = 0; i < MaxLanes(d); ++i) { @@ -2770,39 +2900,56 @@ HWY_API T ReduceSum(D d, VFromD v) { } return sum; } -template > -HWY_API VFromD MinOfLanes(D d, VFromD v) { +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, VFromD v) { T min = HighestValue(); for (size_t i = 0; i < MaxLanes(d); ++i) { min = HWY_MIN(min, v.raw[i]); } - return Set(d, min); + return min; } -template > -HWY_API VFromD MaxOfLanes(D d, VFromD v) { +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, VFromD v) { T max = LowestValue(); for (size_t i = 0; i < MaxLanes(d); ++i) { max = HWY_MAX(max, v.raw[i]); } - return Set(d, max); + return max; +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); } // ================================================== OPS WITH DEPENDENCIES // ------------------------------ MulEven/Odd 64x64 (UpperHalf) -HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { - alignas(16) uint64_t mul[2]; +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); - return Load(Full128(), mul); + return Load(Full128(), mul); } -HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { - alignas(16) uint64_t mul[2]; - const Half> d2; +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + const Half> d2; mul[0] = Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); - return Load(Full128(), mul); + return Load(Full128(), mul); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/r/src/vendor/highway/hwy/ops/generic_ops-inl.h b/r/src/vendor/highway/hwy/ops/generic_ops-inl.h index 9c5ac4a0..99b518d9 100644 --- 
a/r/src/vendor/highway/hwy/ops/generic_ops-inl.h +++ b/r/src/vendor/highway/hwy/ops/generic_ops-inl.h @@ -1,5 +1,6 @@ // Copyright 2021 Google LLC -// Copyright 2023 Arm Limited and/or its affiliates +// Copyright 2023,2024 Arm Limited and/or +// its affiliates // SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: BSD-3-Clause // @@ -17,6 +18,9 @@ // Target-independent types/functions defined after target-specific ops. +// The "include guards" in this file that check HWY_TARGET_TOGGLE serve to skip +// the generic implementation here if native ops are already defined. + #include "hwy/base.h" // Define detail::Shuffle1230 etc, but only when viewing the current header; @@ -56,7 +60,7 @@ HWY_API V Clamp(const V v, const V lo, const V hi) { // CombineShiftRightBytes (and -Lanes) are not available for the scalar target, // and RVV has its own implementation of -Lanes. -#if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV +#if (HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV) || HWY_IDE template HWY_API VFromD CombineShiftRightLanes(D d, VFromD hi, VFromD lo) { @@ -194,6 +198,76 @@ HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, #endif } +// ------------------------------ IsNegative +#if (defined(HWY_NATIVE_IS_NEGATIVE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API Mask> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MaskFromVec(BroadcastSignBit(BitCast(di, v)))); +} + +#endif // HWY_NATIVE_IS_NEGATIVE + +// ------------------------------ MaskFalse +#if (defined(HWY_NATIVE_MASK_FALSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API Mask MaskFalse(D d) { + return MaskFromVec(Zero(d)); +} + +#endif // HWY_NATIVE_MASK_FALSE + +// ------------------------------ IfNegativeThenElseZero +#if (defined(HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO + +// ------------------------------ IfNegativeThenZeroElse +#if (defined(HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE + +// ------------------------------ ZeroIfNegative (IfNegativeThenZeroElse) + +// ZeroIfNegative is generic for all vector lengths +template +HWY_API V ZeroIfNegative(V v) { + return IfNegativeThenZeroElse(v, v); +} + // ------------------------------ BitwiseIfThenElse #if (defined(HWY_NATIVE_BITWISE_IF_THEN_ELSE) == defined(HWY_TARGET_TOGGLE)) #ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE @@ -209,2596 +283,4816 @@ HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { #endif // HWY_NATIVE_BITWISE_IF_THEN_ELSE -// "Include guard": skip if native instructions are available. The generic -// implementation is currently shared between x86_* and wasm_*, and is too large -// to duplicate. 
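// A minimal sketch of the HWY_TARGET_TOGGLE guard idiom referenced above,
// using a hypothetical flag name (HWY_NATIVE_SOME_OP is a placeholder, not a
// real Highway flag). The assumption here is that hwy/foreach_target.h flips
// HWY_TARGET_TOGGLE between defined and undefined on each re-inclusion pass;
// a per-target header that provides a native SomeOp flips its flag in
// lockstep, so the parity test below fails for that pass and the generic
// fallback is skipped.

// In a per-target ops header, only when a native SomeOp exists:
#ifdef HWY_NATIVE_SOME_OP
#undef HWY_NATIVE_SOME_OP
#else
#define HWY_NATIVE_SOME_OP
#endif

// In generic_ops-inl.h:
#if (defined(HWY_NATIVE_SOME_OP) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_SOME_OP
#undef HWY_NATIVE_SOME_OP
#else
#define HWY_NATIVE_SOME_OP
#endif
// ... generic SomeOp fallback would be compiled here ...
#endif  // HWY_NATIVE_SOME_OP

// In this standalone sketch, HWY_TARGET_TOGGLE is undefined, so the
// "per-target" block defines the flag and the generic branch is skipped,
// which is the behaviour the guards in this file rely on.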
+// ------------------------------ PromoteMaskTo -#if HWY_IDE || \ - (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED -#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#if (defined(HWY_NATIVE_PROMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO #else -#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#define HWY_NATIVE_PROMOTE_MASK_TO #endif -// ------------------------------ LoadInterleaved2 +template +HWY_API Mask PromoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert( + sizeof(TFromD) > sizeof(TFromD), + "sizeof(TFromD) must be greater than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); -template -HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1) { - const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] - const VFromD B = LoadU(d, unaligned + Lanes(d)); - v0 = ConcatEven(d, B, A); - v1 = ConcatOdd(d, B, A); -} + const RebindToSigned di_to; + const RebindToSigned di_from; -template -HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1) { - v0 = LoadU(d, unaligned + 0); - v1 = LoadU(d, unaligned + 1); + return MaskFromVec(BitCast( + d_to, PromoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); } -// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) +#endif // HWY_NATIVE_PROMOTE_MASK_TO -namespace detail { +// ------------------------------ DemoteMaskTo -#if HWY_IDE -template -HWY_INLINE V ShuffleTwo1230(V a, V /* b */) { - return a; -} -template -HWY_INLINE V ShuffleTwo2301(V a, V /* b */) { - return a; -} -template -HWY_INLINE V ShuffleTwo3012(V a, V /* b */) { - return a; -} -#endif // HWY_IDE +#if (defined(HWY_NATIVE_DEMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif -// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. -template -HWY_INLINE void LoadTransposedBlocks3(D d, - const TFromD* HWY_RESTRICT unaligned, - VFromD& A, VFromD& B, - VFromD& C) { - constexpr size_t kN = MaxLanes(d); - A = LoadU(d, unaligned + 0 * kN); - B = LoadU(d, unaligned + 1 * kN); - C = LoadU(d, unaligned + 2 * kN); -} +template +HWY_API Mask DemoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert(sizeof(TFromD) < sizeof(TFromD), + "sizeof(TFromD) must be less than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); -} // namespace detail + const RebindToSigned di_to; + const RebindToSigned di_from; -template -HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - const RebindToUnsigned du; - using V = VFromD; - // Compact notation so these fit on one line: 12 := v1[2]. - V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 - V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 - V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a - detail::LoadTransposedBlocks3(d, unaligned, A, B, C); - // Compress all lanes belonging to v0 into consecutive lanes. 
- constexpr uint8_t Z = 0x80; - alignas(16) static constexpr uint8_t kIdx_v0A[16] = { - 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0B[16] = { - Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13}; - alignas(16) static constexpr uint8_t kIdx_v1A[16] = { - 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1B[16] = { - Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14}; - alignas(16) static constexpr uint8_t kIdx_v2A[16] = { - 2, 5, 8, 11, 14, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2B[16] = { - Z, Z, Z, Z, Z, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15}; - const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); - const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); - const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); - const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); - const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); - const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); - const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); - const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); - const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); - v0 = Xor3(v0L, v0M, v0U); - v1 = Xor3(v1L, v1M, v1U); - v2 = Xor3(v2L, v2M, v2U); + return MaskFromVec( + BitCast(d_to, DemoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); } -// 8-bit lanes x8 -template -HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - const RebindToUnsigned du; - using V = VFromD; - V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] - V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] - V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] - detail::LoadTransposedBlocks3(d, unaligned, A, B, C); - // Compress all lanes belonging to v0 into consecutive lanes. 
- constexpr uint8_t Z = 0x80; - alignas(16) static constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, 1, 4, 7, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, 2, 5}; - alignas(16) static constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, 2, 5, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, 0, 3, 6}; - alignas(16) static constexpr uint8_t kIdx_v2A[16] = {2, 5, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2B[16] = {Z, Z, 0, 3, 6, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, 1, 4, 7}; - const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); - const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); - const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); - const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); - const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); - const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); - const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); - const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); - const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); - v0 = Xor3(v0L, v0M, v0U); - v1 = Xor3(v1L, v1M, v1U); - v2 = Xor3(v2L, v2M, v2U); -} +#endif // HWY_NATIVE_DEMOTE_MASK_TO -// 16-bit lanes x8 -template -HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - const RebindToUnsigned du; - const Repartition du8; - using V = VFromD; - V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] - V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] - V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] - detail::LoadTransposedBlocks3(d, unaligned, A, B, C); - // Compress all lanes belonging to v0 into consecutive lanes. Same as above, - // but each element of the array contains a byte index for a byte of a lane. 
- constexpr uint8_t Z = 0x80; - alignas(16) static constexpr uint8_t kIdx_v0A[16] = { - 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0B[16] = { - Z, Z, Z, Z, Z, Z, 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v0C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0x04, 0x05, 0x0A, 0x0B}; - alignas(16) static constexpr uint8_t kIdx_v1A[16] = { - 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1B[16] = { - Z, Z, Z, Z, Z, Z, 0x04, 0x05, 0x0A, 0x0B, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v1C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D}; - alignas(16) static constexpr uint8_t kIdx_v2A[16] = { - 0x04, 0x05, 0x0A, 0x0B, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2B[16] = { - Z, Z, Z, Z, 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D, Z, Z, Z, Z, Z, Z}; - alignas(16) static constexpr uint8_t kIdx_v2C[16] = { - Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F}; - const V v0L = TableLookupBytesOr0(A, BitCast(d, LoadDup128(du8, kIdx_v0A))); - const V v0M = TableLookupBytesOr0(B, BitCast(d, LoadDup128(du8, kIdx_v0B))); - const V v0U = TableLookupBytesOr0(C, BitCast(d, LoadDup128(du8, kIdx_v0C))); - const V v1L = TableLookupBytesOr0(A, BitCast(d, LoadDup128(du8, kIdx_v1A))); - const V v1M = TableLookupBytesOr0(B, BitCast(d, LoadDup128(du8, kIdx_v1B))); - const V v1U = TableLookupBytesOr0(C, BitCast(d, LoadDup128(du8, kIdx_v1C))); - const V v2L = TableLookupBytesOr0(A, BitCast(d, LoadDup128(du8, kIdx_v2A))); - const V v2M = TableLookupBytesOr0(B, BitCast(d, LoadDup128(du8, kIdx_v2B))); - const V v2U = TableLookupBytesOr0(C, BitCast(d, LoadDup128(du8, kIdx_v2C))); - v0 = Xor3(v0L, v0M, v0U); - v1 = Xor3(v1L, v1M, v1U); - v2 = Xor3(v2L, v2M, v2U); +// ------------------------------ CombineMasks + +#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask CombineMasks(D d, Mask> hi, Mask> lo) { + const Half dh; + return MaskFromVec(Combine(d, VecFromMask(dh, hi), VecFromMask(dh, lo))); } +#endif -template -HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - using V = VFromD; - V A; // v0[1] v2[0] v1[0] v0[0] - V B; // v1[2] v0[2] v2[1] v1[1] - V C; // v2[3] v1[3] v0[3] v2[2] - detail::LoadTransposedBlocks3(d, unaligned, A, B, C); +#endif // HWY_NATIVE_COMBINE_MASKS - const V vxx_02_03_xx = OddEven(C, B); - v0 = detail::ShuffleTwo1230(A, vxx_02_03_xx); +// ------------------------------ LowerHalfOfMask - // Shuffle2301 takes the upper/lower halves of the output from one input, so - // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use - // OddEven because it may have higher throughput than Shuffle. 
- const V vxx_xx_10_11 = OddEven(A, B); - const V v12_13_xx_xx = OddEven(B, C); - v1 = detail::ShuffleTwo2301(vxx_xx_10_11, v12_13_xx_xx); +#if (defined(HWY_NATIVE_LOWER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif - const V vxx_20_21_xx = OddEven(B, A); - v2 = detail::ShuffleTwo3012(vxx_20_21_xx, C); +template +HWY_API Mask LowerHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(LowerHalf(d, VecFromMask(dt, m))); } -template -HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - VFromD A; // v1[0] v0[0] - VFromD B; // v0[1] v2[0] - VFromD C; // v2[1] v1[1] - detail::LoadTransposedBlocks3(d, unaligned, A, B, C); - v0 = OddEven(B, A); - v1 = CombineShiftRightBytes)>(d, C, A); - v2 = OddEven(C, B); +#endif // HWY_NATIVE_LOWER_HALF_OF_MASK + +// ------------------------------ UpperHalfOfMask + +#if (defined(HWY_NATIVE_UPPER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask UpperHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(UpperHalf(d, VecFromMask(dt, m))); } +#endif -template , HWY_IF_LANES_D(D, 1)> -HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2) { - v0 = LoadU(d, unaligned + 0); - v1 = LoadU(d, unaligned + 1); - v2 = LoadU(d, unaligned + 2); +#endif // HWY_NATIVE_UPPER_HALF_OF_MASK + +// ------------------------------ OrderedDemote2MasksTo + +#if (defined(HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask OrderedDemote2MasksTo(DTo d_to, DFrom d_from, Mask a, + Mask b) { + static_assert( + sizeof(TFromD) == sizeof(TFromD) / 2, + "sizeof(TFromD) must be equal to sizeof(TFromD) / 2"); + static_assert(IsSame, Mask, DFrom>>>(), + "Mask must be the same type as " + "Mask, DFrom>>>()"); + + const RebindToSigned di_from; + const RebindToSigned di_to; + + const auto va = BitCast(di_from, VecFromMask(d_from, a)); + const auto vb = BitCast(di_from, VecFromMask(d_from, b)); + return MaskFromVec(BitCast(d_to, OrderedDemote2To(di_to, va, vb))); } +#endif -// ------------------------------ LoadInterleaved4 +#endif // HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO -namespace detail { +// ------------------------------ RotateLeft +template +HWY_API V RotateLeft(V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); -// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. + constexpr int kRotateRightAmt = + (kBits == 0) ? 
0 : static_cast(kSizeInBits) - kBits; + return RotateRight(v); +} + +// ------------------------------ InterleaveWholeLower/InterleaveWholeUpper +#if (defined(HWY_NATIVE_INTERLEAVE_WHOLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE template -HWY_INLINE void LoadTransposedBlocks4(D d, - const TFromD* HWY_RESTRICT unaligned, - VFromD& vA, VFromD& vB, - VFromD& vC, VFromD& vD) { - constexpr size_t kN = MaxLanes(d); - vA = LoadU(d, unaligned + 0 * kN); - vB = LoadU(d, unaligned + 1 * kN); - vC = LoadU(d, unaligned + 2 * kN); - vD = LoadU(d, unaligned + 3 * kN); +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + // InterleaveWholeLower(d, a, b) is equivalent to InterleaveLower(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveLower(d, a, b); +} +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // InterleaveWholeUpper(d, a, b) is equivalent to InterleaveUpper(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveUpper(d, a, b); } -} // namespace detail +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on AVX2/AVX3 +// is implemented in x86_256-inl.h. -template -HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - const Repartition d64; - using V64 = VFromD; - using V = VFromD; - // 16 lanes per block; the lowest four blocks are at the bottom of vA..vD. - // Here int[i] means the four interleaved values of the i-th 4-tuple and - // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). - V vA; // int[13..10] int[3..0] - V vB; // int[17..14] int[7..4] - V vC; // int[1b..18] int[b..8] - V vD; // int[1f..1c] int[f..c] - detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); +// InterleaveWholeLower/InterleaveWholeUpper for 64-byte vectors on AVX3 is +// implemented in x86_512-inl.h. - // For brevity, the comments only list the lower block (upper = lower + 0x10) - const V v5140 = InterleaveLower(d, vA, vB); // int[5,1,4,0] - const V vd9c8 = InterleaveLower(d, vC, vD); // int[d,9,c,8] - const V v7362 = InterleaveUpper(d, vA, vB); // int[7,3,6,2] - const V vfbea = InterleaveUpper(d, vC, vD); // int[f,b,e,a] +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on WASM_EMU256 +// is implemented in wasm_256-inl.h. +#endif // HWY_TARGET != HWY_SCALAR - const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] - const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] - const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] - const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] +#endif // HWY_NATIVE_INTERLEAVE_WHOLE - const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] - const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] - const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] - const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// The InterleaveWholeLower without the optional D parameter is generic for all +// vector lengths. 
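// A scalar sketch of the RotateLeft-to-RotateRight reduction used by
// RotateLeft above, for uint32_t (plain C++; the helper names are example
// choices). The kBits == 0 special case matters because rotating right by the
// full lane width would otherwise require a shift by 32, which is undefined
// behaviour.
#include <cstdint>
#include <cstdio>

constexpr uint32_t RotateRightSketch(uint32_t x, unsigned amt) {
  return (amt == 0) ? x : ((x >> amt) | (x << (32u - amt)));
}

constexpr uint32_t RotateLeftSketch(uint32_t x, unsigned k_bits) {
  return RotateRightSketch(x, (k_bits == 0) ? 0u : 32u - k_bits);
}

static_assert(RotateLeftSketch(0x80000001u, 1) == 0x00000003u, "rotl by 1");
static_assert(RotateLeftSketch(0x12345678u, 0) == 0x12345678u, "rotl by 0");

int main() {
  std::printf("%08X\n",
              static_cast<unsigned>(RotateLeftSketch(0x80000001u, 4)));  // 00000018
  return 0;
}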
+template +HWY_API V InterleaveWholeLower(V a, V b) { + return InterleaveWholeLower(DFromV(), a, b); +} +#endif // HWY_TARGET != HWY_SCALAR - v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); - v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); - v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); - v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +// ------------------------------ InterleaveEven + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// InterleaveEven without the optional D parameter is generic for all vector +// lengths +template +HWY_API V InterleaveEven(V a, V b) { + return InterleaveEven(DFromV(), a, b); } +#endif -template -HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - // In the last step, we interleave by half of the block size, which is usually - // 8 bytes but half that for 8-bit x8 vectors. - using TW = hwy::UnsignedFromSize; - const Repartition dw; - using VW = VFromD; +// ------------------------------ AddSub - // (Comments are for 256-bit vectors.) - // 8 lanes per block; the lowest four blocks are at the bottom of vA..vD. - VFromD vA; // v3210[9]v3210[8] v3210[1]v3210[0] - VFromD vB; // v3210[b]v3210[a] v3210[3]v3210[2] - VFromD vC; // v3210[d]v3210[c] v3210[5]v3210[4] - VFromD vD; // v3210[f]v3210[e] v3210[7]v3210[6] - detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); +template , 1)> +HWY_API V AddSub(V a, V b) { + // AddSub(a, b) for a one-lane vector is equivalent to Sub(a, b) + return Sub(a, b); +} - const VFromD va820 = InterleaveLower(d, vA, vB); // v3210[a,8] v3210[2,0] - const VFromD vec64 = InterleaveLower(d, vC, vD); // v3210[e,c] v3210[6,4] - const VFromD vb931 = InterleaveUpper(d, vA, vB); // v3210[b,9] v3210[3,1] - const VFromD vfd75 = InterleaveUpper(d, vC, vD); // v3210[f,d] v3210[7,5] +// AddSub for F32x2, F32x4, and F64x2 vectors is implemented in x86_128-inl.h on +// SSSE3/SSE4/AVX2/AVX3 - const VW v10_b830 = // v10[b..8] v10[3..0] - BitCast(dw, InterleaveLower(d, va820, vb931)); - const VW v10_fc74 = // v10[f..c] v10[7..4] - BitCast(dw, InterleaveLower(d, vec64, vfd75)); - const VW v32_b830 = // v32[b..8] v32[3..0] - BitCast(dw, InterleaveUpper(d, va820, vb931)); - const VW v32_fc74 = // v32[f..c] v32[7..4] - BitCast(dw, InterleaveUpper(d, vec64, vfd75)); +// AddSub for F32x8 and F64x4 vectors is implemented in x86_256-inl.h on +// AVX2/AVX3 - v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); - v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); - v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); - v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +// AddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h + +// AddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V AddSub(V a, V b) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + // Negate the even lanes of b + const auto negated_even_b = OddEven(b, BitCast(d, Neg(BitCast(d_negate, b)))); + + return Add(a, negated_even_b); } -template -HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - using V = VFromD; - V vA; // v3210[4] v3210[0] - V vB; // v3210[5] v3210[1] - V vC; // v3210[6] v3210[2] - V vD; // v3210[7] v3210[3] - detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); - const V v10e = InterleaveLower(d, vA, vC); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] - const V v10o = 
InterleaveLower(d, vB, vD); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] - const V v32e = InterleaveUpper(d, vA, vC); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] - const V v32o = InterleaveUpper(d, vB, vD); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] +// ------------------------------ MaskedAddOr etc. +#if (defined(HWY_NATIVE_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif - v0 = InterleaveLower(d, v10e, v10o); - v1 = InterleaveUpper(d, v10e, v10o); - v2 = InterleaveLower(d, v32e, v32o); - v3 = InterleaveUpper(d, v32e, v32o); +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, Min(a, b), no); } -template -HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - VFromD vA, vB, vC, vD; - detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); - v0 = InterleaveLower(d, vA, vC); - v1 = InterleaveUpper(d, vA, vC); - v2 = InterleaveLower(d, vB, vD); - v3 = InterleaveUpper(d, vB, vD); +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, Max(a, b), no); } -// Any T x1 -template , HWY_IF_LANES_D(D, 1)> -HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, - VFromD& v0, VFromD& v1, VFromD& v2, - VFromD& v3) { - v0 = LoadU(d, unaligned + 0); - v1 = LoadU(d, unaligned + 1); - v2 = LoadU(d, unaligned + 2); - v3 = LoadU(d, unaligned + 3); +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, Add(a, b), no); } -// ------------------------------ StoreInterleaved2 +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, Sub(a, b), no); +} -namespace detail { +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, Mul(a, b), no); +} -// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. -template -HWY_INLINE void StoreTransposedBlocks2(VFromD A, VFromD B, D d, - TFromD* HWY_RESTRICT unaligned) { - constexpr size_t kN = MaxLanes(d); - StoreU(A, d, unaligned + 0 * kN); - StoreU(B, d, unaligned + 1 * kN); +template +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, Div(a, b), no); } -} // namespace detail +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} -// >= 128 bit vector -template -HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, - TFromD* HWY_RESTRICT unaligned) { - const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] - const auto v10U = InterleaveUpper(d, v0, v1); // .. 
v1[kN/2] v0[kN/2] - detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); } -// <= 64 bits -template -HWY_API void StoreInterleaved2(V part0, V part1, D d, - TFromD* HWY_RESTRICT unaligned) { - const Twice d2; - const auto v0 = ZeroExtendVector(d2, part0); - const auto v1 = ZeroExtendVector(d2, part1); - const auto v10 = InterleaveLower(d2, v0, v1); - StoreU(v10, d2, unaligned); +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); } +#endif // HWY_NATIVE_MASKED_ARITH -// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, -// TableLookupBytes) +// ------------------------------ IfNegativeThenNegOrUndefIfZero -namespace detail { +#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif -// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. -template -HWY_INLINE void StoreTransposedBlocks3(VFromD A, VFromD B, VFromD C, - D d, TFromD* HWY_RESTRICT unaligned) { - constexpr size_t kN = MaxLanes(d); - StoreU(A, d, unaligned + 0 * kN); - StoreU(B, d, unaligned + 1 * kN); - StoreU(C, d, unaligned + 2 * kN); +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // MaskedSubOr is more efficient than IfNegativeThenElse on RVV/SVE + const auto zero = Zero(DFromV()); + return MaskedSubOr(v, Lt(mask, zero), zero, v); +#else + return IfNegativeThenElse(mask, Neg(v), v); +#endif } -} // namespace detail - -// >= 128-bit vector, 8-bit lanes -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - const RebindToUnsigned du; - using TU = TFromD; - const auto k5 = Set(du, TU{5}); - const auto k6 = Set(du, TU{6}); +#endif // HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG - // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): - // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes - // to their place, with 0x80 so lanes to be filled from other vectors are 0 - // to enable blending by ORing together. - alignas(16) static constexpr uint8_t tbl_v0[16] = { - 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // - 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; - alignas(16) static constexpr uint8_t tbl_v1[16] = { - 0x80, 0, 0x80, 0x80, 1, 0x80, // - 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; - // The interleaved vectors will be named A, B, C; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A0 = LoadDup128(du, tbl_v0); - const auto shuf_A1 = LoadDup128(du, tbl_v1); // cannot reuse shuf_A0 (has 5) - const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. - const VFromD A = BitCast(d, A0 | A1 | A2); +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { + return CopySign(v, Xor(mask, v)); +} - // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] - const auto shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. 
- const auto shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 - const auto shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. - const auto B0 = TableLookupBytesOr0(v0, shuf_B0); - const auto B1 = TableLookupBytesOr0(v1, shuf_B1); - const auto B2 = TableLookupBytesOr0(v2, shuf_B2); - const VFromD B = BitCast(d, B0 | B1 | B2); +// ------------------------------ SaturatedNeg - // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] - const auto shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. - const auto shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. - const auto shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A - const auto C0 = TableLookupBytesOr0(v0, shuf_C0); - const auto C1 = TableLookupBytesOr0(v1, shuf_C1); - const auto C2 = TableLookupBytesOr0(v2, shuf_C2); - const VFromD C = BitCast(d, C0 | C1 | C2); +#if (defined(HWY_NATIVE_SATURATED_NEG_8_16_32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif - detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +template +HWY_API V SaturatedNeg(V v) { + const DFromV d; + return SaturatedSub(Zero(d), v); } -// >= 128-bit vector, 16-bit lanes -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - const Repartition du8; - const auto k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); - const auto k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); +template )> +HWY_API V SaturatedNeg(V v) { + const DFromV d; - // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): - // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be - // filled from other vectors are 0 for blending. Note that these are byte - // indices for 16-bit lanes. - alignas(16) static constexpr uint8_t tbl_v1[16] = { - 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, - 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; - alignas(16) static constexpr uint8_t tbl_v2[16] = { - 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, - 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_PPC || HWY_TARGET_IS_SVE || \ + HWY_TARGET_IS_NEON + // RVV/PPC/SVE/NEON have native I32 SaturatedSub instructions + return SaturatedSub(Zero(d), v); +#else + // ~v[i] - ((v[i] > LimitsMin()) ? -1 : 0) is equivalent to + // (v[i] > LimitsMin) ? (-v[i]) : LimitsMax() since + // -v[i] == ~v[i] + 1 == ~v[i] - (-1) and + // ~LimitsMin() == LimitsMax(). + return Sub(Not(v), VecFromMask(d, Gt(v, Set(d, LimitsMin())))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_8_16_32 - // The interleaved vectors will be named A, B, C; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A1 = LoadDup128(du8, tbl_v1); // 2..1..0. - // .2..1..0 - const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); - const auto shuf_A2 = LoadDup128(du8, tbl_v2); // ..1..0.. - - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); - const VFromD A = BitCast(d, A0 | A1 | A2); +#if (defined(HWY_NATIVE_SATURATED_NEG_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif - // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] - const auto shuf_B0 = shuf_A1 + k3; // 5..4..3. - const auto shuf_B1 = shuf_A2 + k3; // ..4..3.. 
- const auto shuf_B2 = shuf_A0 + k2; // .4..3..2 - const auto B0 = TableLookupBytesOr0(v0, shuf_B0); - const auto B1 = TableLookupBytesOr0(v1, shuf_B1); - const auto B2 = TableLookupBytesOr0(v2, shuf_B2); - const VFromD B = BitCast(d, B0 | B1 | B2); +template )> +HWY_API V SaturatedNeg(V v) { +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_SVE || HWY_TARGET_IS_NEON + // RVV/SVE/NEON have native I64 SaturatedSub instructions + const DFromV d; + return SaturatedSub(Zero(d), v); +#else + const auto neg_v = Neg(v); + return Add(neg_v, BroadcastSignBit(And(v, neg_v))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_64 - // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] - const auto shuf_C0 = shuf_B1 + k3; // ..7..6.. - const auto shuf_C1 = shuf_B2 + k3; // .7..6..5 - const auto shuf_C2 = shuf_B0 + k2; // 7..6..5. - const auto C0 = TableLookupBytesOr0(v0, shuf_C0); - const auto C1 = TableLookupBytesOr0(v1, shuf_C1); - const auto C2 = TableLookupBytesOr0(v2, shuf_C2); - const VFromD C = BitCast(d, C0 | C1 | C2); +// ------------------------------ SaturatedAbs - detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +#if (defined(HWY_NATIVE_SATURATED_ABS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +template +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedNeg(v)); } -// >= 128-bit vector, 32-bit lanes -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - const RepartitionToWide dw; +#endif - const VFromD v10_v00 = InterleaveLower(d, v0, v1); - const VFromD v01_v20 = OddEven(v0, v2); - // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) - const VFromD A = BitCast( - d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); +// ------------------------------ Reductions - const VFromD v1_321 = ShiftRightLanes<1>(d, v1); - const VFromD v0_32 = ShiftRightLanes<2>(d, v0); - const VFromD v21_v11 = OddEven(v2, v1_321); - const VFromD v12_v02 = OddEven(v1_321, v0_32); - // B: v1[2],v0[2], v2[1],v1[1] - const VFromD B = BitCast( - d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); +// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled, +// they (RVV/SVE/Armv8/Emu128) implement ReduceSum and SumOfLanes via Set. +// Otherwise, they (Armv7/PPC/scalar/WASM/x86) define zero to most of the +// SumOfLanes overloads. For the latter group, we here define the remaining +// overloads, plus ReduceSum which uses them plus GetLane. +#if (defined(HWY_NATIVE_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif - // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. - const VFromD v23_v13 = OddEven(v2, v1_321); - const VFromD v03_v22 = OddEven(v0, v2); - // C: v2[3],v1[3],v0[3], v2[2] - const VFromD C = BitCast( - d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); +namespace detail { - detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +// Allows reusing the same shuffle code for SumOfLanes/MinOfLanes/MaxOfLanes. +struct AddFunc { + template + V operator()(V a, V b) const { + return Add(a, b); + } +}; + +struct MinFunc { + template + V operator()(V a, V b) const { + return Min(a, b); + } +}; + +struct MaxFunc { + template + V operator()(V a, V b) const { + return Max(a, b); + } +}; + +// No-op for vectors of at most one block. 
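// A scalar sketch of the "combine with a reversed copy" reduction that the
// ReduceWithinBlocks overloads below perform with Reverse8/Reverse4/Reverse2
// (plain C++, 8 lanes, addition as the combining function; all names here are
// example choices). Each pass halves the group size, and after the last pass
// every lane holds the full reduction, so the broadcasting SumOfLanes-style
// overloads can return the result directly.
#include <array>
#include <cstddef>
#include <cstdio>

int ReduceWithinBlockSketch(std::array<int, 8> v) {
  for (std::size_t half = 4; half >= 1; half /= 2) {
    // Mirror each lane within its current group of 2*half lanes, then combine.
    std::array<int, 8> mirrored = v;
    for (std::size_t group = 0; group < 8; group += 2 * half) {
      for (std::size_t i = 0; i < 2 * half; ++i) {
        mirrored[group + i] = v[group + (2 * half - 1 - i)];
      }
    }
    for (std::size_t i = 0; i < 8; ++i) v[i] += mirrored[i];
  }
  return v[0];  // every lane now holds the total
}

int main() {
  std::printf("%d\n", ReduceWithinBlockSketch({1, 2, 3, 4, 5, 6, 7, 8}));  // 36
  return 0;
}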
+template +HWY_INLINE VFromD ReduceAcrossBlocks(D, Func, VFromD v) { + return v; } -// >= 128-bit vector, 64-bit lanes -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - const VFromD A = InterleaveLower(d, v0, v1); - const VFromD B = OddEven(v0, v2); - const VFromD C = InterleaveUpper(d, v1, v2); - detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +// Reduces a lane with its counterpart in other block(s). Shared by AVX2 and +// WASM_EMU256. AVX3 has its own overload. +template +HWY_INLINE VFromD ReduceAcrossBlocks(D /*d*/, Func f, VFromD v) { + return f(v, SwapAdjacentBlocks(v)); } -// 64-bit vector, 8-bit lanes -template -HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, - VFromD part2, D d, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors for the shuffles and first result. - constexpr size_t kFullN = 16 / sizeof(TFromD); - const Full128 du; - const Full128> d_full; - const auto k5 = Set(du, uint8_t{5}); - const auto k6 = Set(du, uint8_t{6}); +// These return the reduction result broadcasted across all lanes. They assume +// the caller has already reduced across blocks. - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v10) { + return f(v10, Reverse2(d, v10)); +} - // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): - // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be - // filled from other vectors are 0 for blending. - alignas(16) static constexpr uint8_t tbl_v0[16] = { - 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // - 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; - alignas(16) static constexpr uint8_t tbl_v1[16] = { - 0x80, 0, 0x80, 0x80, 1, 0x80, // - 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; - // The interleaved vectors will be named A, B, C; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A0 = Load(du, tbl_v0); - const auto shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) - const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. - const auto A = BitCast(d_full, A0 | A1 | A2); - StoreU(A, d_full, unaligned + 0 * kFullN); +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v3210) { + const VFromD v0123 = Reverse4(d, v3210); + const VFromD v03_12_12_03 = f(v3210, v0123); + const VFromD v12_03_03_12 = Reverse2(d, v03_12_12_03); + return f(v03_12_12_03, v12_03_03_12); +} - // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] - const auto shuf_B0 = shuf_A2 + k6; // ..7..6.. - const auto shuf_B1 = shuf_A0 + k5; // .7..6..5 - const auto shuf_B2 = shuf_A1 + k5; // 7..6..5. - const auto B0 = TableLookupBytesOr0(v0, shuf_B0); - const auto B1 = TableLookupBytesOr0(v1, shuf_B1); - const auto B2 = TableLookupBytesOr0(v2, shuf_B2); - const VFromD B{BitCast(d_full, B0 | B1 | B2).raw}; - StoreU(B, d, unaligned + 1 * kFullN); +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v76543210) { + // The upper half is reversed from the lower half; omit for brevity. 
+ const VFromD v34_25_16_07 = f(v76543210, Reverse8(d, v76543210)); + const VFromD v0347_1625_1625_0347 = + f(v34_25_16_07, Reverse4(d, v34_25_16_07)); + return f(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); } -// 64-bit vector, 16-bit lanes -template -HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, - VFromD part2, D dh, - TFromD* HWY_RESTRICT unaligned) { - const Twice d_full; - const Full128 du8; - const auto k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); - const auto k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); - - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. + const VW even = And(vw, Set(dw, 0xFF)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} - // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): - // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes - // to their place, with 0x80 so lanes to be filled from other vectors are 0 - // to enable blending by ORing together. - alignas(16) static constexpr uint8_t tbl_v1[16] = { - 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, - 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; - alignas(16) static constexpr uint8_t tbl_v2[16] = { - 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, - 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // Sign-extend + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. + const VW even = ShiftRight<8>(ShiftLeft<8>(vw)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} - // The interleaved vectors will be named A, B; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A1 = Load(du8, tbl_v1); // 2..1..0. - // .2..1..0 - const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); - const auto shuf_A2 = Load(du8, tbl_v2); // ..1..0.. - - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); - const VFromD A = BitCast(d_full, A0 | A1 | A2); - StoreU(A, d_full, unaligned); +} // namespace detail - // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] - const auto shuf_B0 = shuf_A1 + k3; // ..3. - const auto shuf_B1 = shuf_A2 + k3; // .3.. 
- const auto shuf_B2 = shuf_A0 + k2; // 3..2 - const auto B0 = TableLookupBytesOr0(v0, shuf_B0); - const auto B1 = TableLookupBytesOr0(v1, shuf_B1); - const auto B2 = TableLookupBytesOr0(v2, shuf_B2); - const VFromD B = BitCast(d_full, B0 | B1 | B2); - StoreU(VFromD{B.raw}, dh, unaligned + MaxLanes(d_full)); +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const detail::AddFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); } - -// 64-bit vector, 32-bit lanes -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - // (same code as 128-bit vector, 64-bit lanes) - const VFromD v10_v00 = InterleaveLower(d, v0, v1); - const VFromD v01_v20 = OddEven(v0, v2); - const VFromD v21_v11 = InterleaveUpper(d, v1, v2); - constexpr size_t kN = MaxLanes(d); - StoreU(v10_v00, d, unaligned + 0 * kN); - StoreU(v01_v20, d, unaligned + 1 * kN); - StoreU(v21_v11, d, unaligned + 2 * kN); +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + const detail::MinFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + const detail::MaxFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); } -// 64-bit lanes are handled by the N=1 case below. +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + return GetLane(SumOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return GetLane(MinOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return GetLane(MaxOfLanes(d, v)); +} -// <= 32-bit vector, 8-bit lanes -template -HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, - VFromD part2, D d, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors for the shuffles and result. - const Full128 du; - const Full128> d_full; +#endif // HWY_NATIVE_REDUCE_SCALAR - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; +// Corner cases for both generic and native implementations: +// N=1 (native covers N=2 e.g. for u64x2 and even u32x2 on Arm) +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMin(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMax(D /*d*/, VFromD v) { + return GetLane(v); +} - // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 - // so lanes to be filled from other vectors are 0 to enable blending by ORing - // together. - alignas(16) static constexpr uint8_t tbl_v0[16] = { - 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, - 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; - // The interleaved vector will be named A; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A0 = Load(du, tbl_v0); - const auto shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); - const auto shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. 
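// Usage sketch, not part of the vendored sources: ReduceSum returns a scalar
// total, whereas SumOfLanes broadcasts the total to every lane; per the
// definitions above, ReduceSum(d, v) == GetLane(SumOfLanes(d, v)). The helper
// name is hypothetical.
#include <stddef.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Sums the first Lanes(d) floats of `in`.
float SumOfOneVector(const float* HWY_RESTRICT in) {
  const hn::ScalableTag<float> d;
  const auto v = hn::LoadU(d, in);
  return hn::ReduceSum(d, v);
}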
- const VFromD A = BitCast(d_full, A0 | A1 | A2); - alignas(16) TFromD buf[MaxLanes(d_full)]; - StoreU(A, d_full, buf); - CopyBytes(buf, unaligned); +template +HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { + return v; } -// 32-bit vector, 16-bit lanes -template -HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, - VFromD part2, D d, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors for the shuffles and result. - const Full128 du8; - const Full128> d_full; +// N=4 for 8-bit is still less than the minimum native size. - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; +// ARMv7 NEON/PPC/RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceSum operations +#if (defined(HWY_NATIVE_REDUCE_SUM_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceSum(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_SUM_4_UI8 + +// RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceMin/ReduceMax operations +#if (defined(HWY_NATIVE_REDUCE_MINMAX_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMin(dw, PromoteTo(dw, v))); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMax(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8 + +// ------------------------------ IsEitherNaN +#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif - // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 - // so lanes to be filled from other vectors are 0 to enable blending by ORing - // together. - alignas(16) static constexpr uint8_t tbl_v2[16] = { - 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, - 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; - // The interleaved vector will be named A; temporaries with suffix - // 0..2 indicate which input vector's lanes they hold. - const auto shuf_A2 = // ..1..0.. - Load(du8, tbl_v2); - const auto shuf_A1 = // ...1..0. - CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); - const auto shuf_A0 = // ....1..0 - CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); - const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 - const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. - const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. 
- const auto A = BitCast(d_full, A0 | A1 | A2); - alignas(16) TFromD buf[MaxLanes(d_full)]; - StoreU(A, d_full, buf); - CopyBytes(buf, unaligned); +template +HWY_API MFromD> IsEitherNaN(V a, V b) { + return Or(IsNaN(a), IsNaN(b)); } -// Single-element vector, any lane size: just store directly -template -HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, - TFromD* HWY_RESTRICT unaligned) { - StoreU(v0, d, unaligned + 0); - StoreU(v1, d, unaligned + 1); - StoreU(v2, d, unaligned + 2); -} +#endif // HWY_NATIVE_IS_EITHER_NAN -// ------------------------------ StoreInterleaved4 +// ------------------------------ IsInf, IsFinite -namespace detail { +// AVX3 has target-specific implementations of these. +#if (defined(HWY_NATIVE_ISINF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif -// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. -template -HWY_INLINE void StoreTransposedBlocks4(VFromD vA, VFromD vB, VFromD vC, - VFromD vD, D d, - TFromD* HWY_RESTRICT unaligned) { - constexpr size_t kN = MaxLanes(d); - StoreU(vA, d, unaligned + 0 * kN); - StoreU(vB, d, unaligned + 1 * kN); - StoreU(vC, d, unaligned + 2 * kN); - StoreU(vD, d, unaligned + 3 * kN); +template > +HWY_API MFromD IsInf(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask( + d, + Eq(Add(vu, vu), + Set(du, static_cast>(hwy::MaxExponentTimes2())))); } -} // namespace detail +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); +// 'Shift left' to clear the sign bit. MSVC seems to generate incorrect code +// for AVX2 if we instead add vu + vu. +#if HWY_COMPILER_MSVC + const VFromD shl = ShiftLeft<1>(vu); +#else + const VFromD shl = Add(vu, vu); +#endif -// >= 128-bit vector, 8..32-bit lanes -template -HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, - VFromD v3, D d, - TFromD* HWY_RESTRICT unaligned) { - const RepartitionToWide dw; - const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] - const auto v32L = ZipLower(dw, v2, v3); - const auto v10U = ZipUpper(dw, v0, v1); - const auto v32U = ZipUpper(dw, v2, v3); - // The interleaved vectors are vA, vB, vC, vD. - const VFromD vA = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 - const VFromD vB = BitCast(d, InterleaveUpper(dw, v10L, v32L)); - const VFromD vC = BitCast(d, InterleaveLower(dw, v10U, v32U)); - const VFromD vD = BitCast(d, InterleaveUpper(dw, v10U, v32U)); - detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); + // Then shift right so we can compare with the max exponent (cannot compare + // with MaxExponentTimes2 directly because it is negative and non-negative + // floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(shl)); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); } -// >= 128-bit vector, 64-bit lanes -template -HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, - VFromD v3, D d, - TFromD* HWY_RESTRICT unaligned) { - // The interleaved vectors are vA, vB, vC, vD. 
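// Usage sketch, not part of the vendored sources: the IsNaN/IsInf/IsFinite
// predicates return masks, which compose with the usual mask ops such as
// IfThenElseZero. The helper name and the "n is a multiple of Lanes(d)"
// simplification are for illustration only.
#include <stddef.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Copies `in` to `out`, replacing Inf/NaN lanes with 0.
void ZeroNonFinite(const float* HWY_RESTRICT in, float* HWY_RESTRICT out,
                   size_t n) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  for (size_t i = 0; i + N <= n; i += N) {
    const auto v = hn::LoadU(d, in + i);
    // Keep normal/subnormal/zero lanes, zero the rest.
    hn::StoreU(hn::IfThenElseZero(hn::IsFinite(v), v), d, out + i);
  }
}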
- const VFromD vA = InterleaveLower(d, v0, v1); // v1[0] v0[0] - const VFromD vB = InterleaveLower(d, v2, v3); - const VFromD vC = InterleaveUpper(d, v0, v1); - const VFromD vD = InterleaveUpper(d, v2, v3); - detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); -} +#endif // HWY_NATIVE_ISINF -// 64-bit vector, 8..32-bit lanes -template -HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, - VFromD part2, VFromD part3, D /* tag */, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors to reduce the number of stores. - const Full128> d_full; - const RepartitionToWide dw; - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; - const VFromD v3{part3.raw}; - const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] - const auto v32 = ZipLower(dw, v2, v3); - const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); - const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); - StoreU(A, d_full, unaligned); - StoreU(B, d_full, unaligned + MaxLanes(d_full)); -} +// ------------------------------ CeilInt/FloorInt +#if (defined(HWY_NATIVE_CEIL_FLOOR_INT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif -// 64-bit vector, 64-bit lane -template -HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, - VFromD part2, VFromD part3, D /* tag */, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors to reduce the number of stores. - const Full128> d_full; - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; - const VFromD v3{part3.raw}; - const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] - const auto B = InterleaveLower(d_full, v2, v3); - StoreU(A, d_full, unaligned); - StoreU(B, d_full, unaligned + MaxLanes(d_full)); +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Ceil(v)); } -// <= 32-bit vectors -template -HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, - VFromD part2, VFromD part3, D d, - TFromD* HWY_RESTRICT unaligned) { - // Use full vectors to reduce the number of stores. - const Full128> d_full; - const RepartitionToWide dw; - const VFromD v0{part0.raw}; - const VFromD v1{part1.raw}; - const VFromD v2{part2.raw}; - const VFromD v3{part3.raw}; - const auto v10 = ZipLower(dw, v0, v1); // .. 
v1[0] v0[0] - const auto v32 = ZipLower(dw, v2, v3); - const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); - alignas(16) TFromD buf[MaxLanes(d_full)]; - StoreU(v3210, d_full, buf); - CopyBytes(buf, unaligned); +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Floor(v)); } -#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED - -// ------------------------------ LoadN +#endif // HWY_NATIVE_CEIL_FLOOR_INT -#if (defined(HWY_NATIVE_LOAD_N) == defined(HWY_TARGET_TOGGLE)) +// ------------------------------ MulByPow2/MulByFloorPow2 -#ifdef HWY_NATIVE_LOAD_N -#undef HWY_NATIVE_LOAD_N +#if (defined(HWY_NATIVE_MUL_BY_POW2) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 #else -#define HWY_NATIVE_LOAD_N +#define HWY_NATIVE_MUL_BY_POW2 #endif -#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE -namespace detail { +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV df; + const RebindToUnsigned du; + const RebindToSigned di; -template -HWY_INLINE VFromD LoadNResizeBitCast(DTo d_to, DFrom d_from, - VFromD v) { -#if HWY_TARGET <= HWY_SSE2 - // On SSE2/SSSE3/SSE4, the LoadU operation will zero out any lanes of v.raw + using TF = TFromD; + using TI = TFromD; + using TU = TFromD; + + using VF = VFromD; + using VI = VFromD; + + constexpr TI kMaxBiasedExp = MaxExponentField(); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr TI kExpBias = static_cast(kMaxBiasedExp >> 1); + static_assert(kExpBias > 0, "kExpBias > 0 must be true"); + static_assert(kExpBias <= LimitsMax() / 3, + "kExpBias <= LimitsMax() / 3 must be true"); + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + using TExpMinMax = If<(sizeof(TI) <= 4), TI, int32_t>; +#elif (HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2) || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_WASM_EMU256 + using TExpMinMax = int16_t; +#else + using TExpMinMax = TI; +#endif + +#if HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SCALAR + using TExpSatSub = TU; +#elif HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 + using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, uint16_t>; +#elif HWY_TARGET_IS_PPC + using TExpSatSub = If<(sizeof(TF) >= 4), uint32_t, TU>; +#else + using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, TU>; +#endif + + static_assert(kExpBias <= static_cast(LimitsMax() / 3), + "kExpBias <= LimitsMax() / 3 must be true"); + + const Repartition d_exp_min_max; + const Repartition d_sat_exp_sub; + + constexpr int kNumOfExpBits = ExponentBits(); + constexpr int kNumOfMantBits = MantissaBits(); + + // The sign bit of BitCastScalar(a[i]) >> kNumOfMantBits can be zeroed out + // using SaturatedSub if kZeroOutSignUsingSatSub is true. + + // If kZeroOutSignUsingSatSub is true, then val_for_exp_sub will be bitcasted + // to a vector that has a smaller lane size than TU for the SaturatedSub + // operation below. + constexpr bool kZeroOutSignUsingSatSub = + ((sizeof(TExpSatSub) * 8) == static_cast(kNumOfExpBits)); + + // If kZeroOutSignUsingSatSub is true, then the upper + // (sizeof(TU) - sizeof(TExpSatSub)) * 8 bits of kExpDecrBy1Bits will be all + // ones and the lower sizeof(TExpSatSub) * 8 bits of kExpDecrBy1Bits will be + // equal to 1. + + // Otherwise, if kZeroOutSignUsingSatSub is false, kExpDecrBy1Bits will be + // equal to 1. 
+ constexpr TU kExpDecrBy1Bits = static_cast( + TU{1} - (static_cast(kZeroOutSignUsingSatSub) << kNumOfExpBits)); + + VF val_for_exp_sub = v; + HWY_IF_CONSTEXPR(!kZeroOutSignUsingSatSub) { + // If kZeroOutSignUsingSatSub is not true, zero out the sign bit of + // val_for_exp_sub[i] using Abs + val_for_exp_sub = Abs(val_for_exp_sub); + } + + // min_exp1_plus_min_exp2[i] is the smallest exponent such that + // min_exp1_plus_min_exp2[i] >= 2 - kExpBias * 2 and + // std::ldexp(v[i], min_exp1_plus_min_exp2[i]) is a normal floating-point + // number if v[i] is a normal number + const VI min_exp1_plus_min_exp2 = BitCast( + di, + Max(BitCast( + d_exp_min_max, + Neg(BitCast( + di, + SaturatedSub( + BitCast(d_sat_exp_sub, ShiftRight( + BitCast(du, val_for_exp_sub))), + BitCast(d_sat_exp_sub, Set(du, kExpDecrBy1Bits)))))), + BitCast(d_exp_min_max, + Set(di, static_cast(2 - kExpBias - kExpBias))))); + + const VI clamped_exp = + Max(Min(exp, Set(di, static_cast(kExpBias * 3))), + Add(min_exp1_plus_min_exp2, Set(di, static_cast(1 - kExpBias)))); + + const VI exp1_plus_exp2 = BitCast( + di, Max(Min(BitCast(d_exp_min_max, + Sub(clamped_exp, ShiftRight<2>(clamped_exp))), + BitCast(d_exp_min_max, + Set(di, static_cast(kExpBias + kExpBias)))), + BitCast(d_exp_min_max, min_exp1_plus_min_exp2))); + + const VI exp1 = ShiftRight<1>(exp1_plus_exp2); + const VI exp2 = Sub(exp1_plus_exp2, exp1); + const VI exp3 = Sub(clamped_exp, exp1_plus_exp2); + + const VI exp_bias = Set(di, kExpBias); + + const VF factor1 = + BitCast(df, ShiftLeft(Add(exp1, exp_bias))); + const VF factor2 = + BitCast(df, ShiftLeft(Add(exp2, exp_bias))); + const VF factor3 = + BitCast(df, ShiftLeft(Add(exp3, exp_bias))); + + return Mul(Mul(Mul(v, factor1), factor2), factor3); +} + +template +HWY_API V MulByFloorPow2(V v, V exp) { + const DFromV df; + + // MulByFloorPow2 special cases: + // MulByFloorPow2(v, NaN) => NaN + // MulByFloorPow2(0, inf) => NaN + // MulByFloorPow2(inf, -inf) => NaN + // MulByFloorPow2(-inf, -inf) => NaN + const auto is_special_case_with_nan_result = + Or(IsNaN(exp), + And(Eq(Abs(v), IfNegativeThenElseZero(exp, Inf(df))), IsInf(exp))); + + return IfThenElse(is_special_case_with_nan_result, NaN(df), + MulByPow2(v, FloorInt(exp))); +} + +#endif // HWY_NATIVE_MUL_BY_POW2 + +// ------------------------------ LoadInterleaved2 + +#if HWY_IDE || \ + (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] + const VFromD B = LoadU(d, unaligned + Lanes(d)); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +#if HWY_IDE +template +HWY_INLINE V ShuffleTwo1230(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo2301(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo3012(V a, V /* b */) { + return a; +} +#endif // HWY_IDE + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. 
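// Usage sketch, not part of the vendored sources: MulByPow2 behaves like a
// per-lane ldexp, scaling v[i] by 2^exp[i], where exp is a signed-integer
// vector of the same lane width (here obtained via RebindToSigned). The
// helper name is hypothetical; the overload constraints are as declared above.
#include <stddef.h>
#include <stdint.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

void ScaleByPow2(const float* HWY_RESTRICT in, const int32_t* HWY_RESTRICT exp,
                 float* HWY_RESTRICT out, size_t n) {
  const hn::ScalableTag<float> df;
  const hn::RebindToSigned<decltype(df)> di;  // i32 lanes matching df
  const size_t N = hn::Lanes(df);
  for (size_t i = 0; i + N <= n; i += N) {
    const auto v = hn::LoadU(df, in + i);
    const auto e = hn::LoadU(di, exp + i);
    hn::StoreU(hn::MulByPow2(v, e), df, out + i);  // v[i] * 2^exp[i]
  }
}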
+template +HWY_INLINE void LoadTransposedBlocks3(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, + VFromD& C) { + constexpr size_t kN = MaxLanes(d); + A = LoadU(d, unaligned + 0 * kN); + B = LoadU(d, unaligned + 1 * kN); + C = LoadU(d, unaligned + 2 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
+ constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, 1, 4, 7, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, 2, 5, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, Z, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, 0, 3, 6, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + const Repartition du8; + using V = VFromD; + using VU8 = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains a byte index for a byte of a lane. 
+ constexpr uint8_t Z = 0x80; + const VU8 idx_v0A = Dup128VecFromValues(du8, 0x00, 0x01, 0x06, 0x07, 0x0C, + 0x0D, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v0B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x02, 0x03, + 0x08, 0x09, 0x0E, 0x0F, Z, Z, Z, Z); + const VU8 idx_v0C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + Z, 0x04, 0x05, 0x0A, 0x0B); + const VU8 idx_v1A = Dup128VecFromValues(du8, 0x02, 0x03, 0x08, 0x09, 0x0E, + 0x0F, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x04, 0x05, + 0x0A, 0x0B, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D); + const VU8 idx_v2A = Dup128VecFromValues(du8, 0x04, 0x05, 0x0A, 0x0B, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2B = Dup128VecFromValues(du8, Z, Z, Z, Z, 0x00, 0x01, 0x06, + 0x07, 0x0C, 0x0D, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F); + const V v0L = TableLookupBytesOr0(A, BitCast(d, idx_v0A)); + const V v0M = TableLookupBytesOr0(B, BitCast(d, idx_v0B)); + const V v0U = TableLookupBytesOr0(C, BitCast(d, idx_v0C)); + const V v1L = TableLookupBytesOr0(A, BitCast(d, idx_v1A)); + const V v1M = TableLookupBytesOr0(B, BitCast(d, idx_v1B)); + const V v1U = TableLookupBytesOr0(C, BitCast(d, idx_v1C)); + const V v2L = TableLookupBytesOr0(A, BitCast(d, idx_v2A)); + const V v2M = TableLookupBytesOr0(B, BitCast(d, idx_v2B)); + const V v2U = TableLookupBytesOr0(C, BitCast(d, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + using V = VFromD; + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::ShuffleTwo1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::ShuffleTwo2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::ShuffleTwo3012(vxx_20_21_xx, C); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + VFromD A; // v1[0] v0[0] + VFromD B; // v0[1] v2[0] + VFromD C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes)>(d, C, A); + v2 = OddEven(C, B); +} + +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. 
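// Usage sketch, not part of the vendored sources: LoadInterleaved3
// de-interleaves packed triplets (e.g. RGB bytes) into three planar vectors.
// The helper name is hypothetical; it assumes `rgb` holds at least
// 3 * Lanes(d) bytes and each output buffer holds Lanes(d) bytes.
#include <stddef.h>
#include <stdint.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

void SplitRgbOnce(const uint8_t* HWY_RESTRICT rgb, uint8_t* HWY_RESTRICT r,
                  uint8_t* HWY_RESTRICT g, uint8_t* HWY_RESTRICT b) {
  const hn::ScalableTag<uint8_t> d;
  hn::VFromD<decltype(d)> vr = hn::Zero(d), vg = hn::Zero(d), vb = hn::Zero(d);
  hn::LoadInterleaved3(d, rgb, vr, vg, vb);  // reads 3 * Lanes(d) bytes
  hn::StoreU(vr, d, r);
  hn::StoreU(vg, d, g);
  hn::StoreU(vb, d, b);
}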
+template +HWY_INLINE void LoadTransposedBlocks4(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, + VFromD& vC, VFromD& vD) { + constexpr size_t kN = MaxLanes(d); + vA = LoadU(d, unaligned + 0 * kN); + vB = LoadU(d, unaligned + 1 * kN); + vC = LoadU(d, unaligned + 2 * kN); + vD = LoadU(d, unaligned + 3 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const Repartition d64; + using V64 = VFromD; + using V = VFromD; + // 16 lanes per block; the lowest four blocks are at the bottom of vA..vD. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V vA; // int[13..10] int[3..0] + V vB; // int[17..14] int[7..4] + V vC; // int[1b..18] int[b..8] + V vD; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, vA, vB); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, vC, vD); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, vA, vB); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, vC, vD); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize; + const Repartition dw; + using VW = VFromD; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of vA..vD. 
+ VFromD vA; // v3210[9]v3210[8] v3210[1]v3210[0] + VFromD vB; // v3210[b]v3210[a] v3210[3]v3210[2] + VFromD vC; // v3210[d]v3210[c] v3210[5]v3210[4] + VFromD vD; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + const VFromD va820 = InterleaveLower(d, vA, vB); // v3210[a,8] v3210[2,0] + const VFromD vec64 = InterleaveLower(d, vC, vD); // v3210[e,c] v3210[6,4] + const VFromD vb931 = InterleaveUpper(d, vA, vB); // v3210[b,9] v3210[3,1] + const VFromD vfd75 = InterleaveUpper(d, vC, vD); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + using V = VFromD; + V vA; // v3210[4] v3210[0] + V vB; // v3210[5] v3210[1] + V vC; // v3210[6] v3210[2] + V vD; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + const V v10e = InterleaveLower(d, vA, vC); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10o = InterleaveLower(d, vB, vD); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32e = InterleaveUpper(d, vA, vC); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32o = InterleaveUpper(d, vB, vD); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10e, v10o); + v1 = InterleaveUpper(d, v10e, v10o); + v2 = InterleaveLower(d, v32e, v32o); + v3 = InterleaveUpper(d, v32e, v32o); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + VFromD vA, vB, vC, vD; + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + v0 = InterleaveLower(d, vA, vC); + v1 = InterleaveUpper(d, vA, vC); + v2 = InterleaveLower(d, vB, vD); + v3 = InterleaveUpper(d, vB, vD); +} + +// Any T x1 +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks2(VFromD A, VFromD B, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); +} + +} // namespace detail + +// >= 128 bit vector +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. 
v1[kN/2] v0[kN/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// <= 64 bits +template +HWY_API void StoreInterleaved2(V part0, V part1, D d, + TFromD* HWY_RESTRICT unaligned) { + const Twice d2; + const auto v0 = ZeroExtendVector(d2, part0); + const auto v1 = ZeroExtendVector(d2, part1); + const auto v10 = InterleaveLower(d2, v0, v1); + StoreU(v10, d2, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks3(VFromD A, VFromD B, VFromD C, + D d, TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); + StoreU(C, d, unaligned + 2 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + using VU = VFromD; + const VU k5 = Set(du, TU{5}); + const VU k6 = Set(du, TU{6}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + const VFromD shuf_A0 = + Dup128VecFromValues(du, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, + 0x80, 0x80, 4, 0x80, 0x80, 5); + // Cannot reuse shuf_A0 because it contains 5. + const VFromD shuf_A1 = + Dup128VecFromValues(du, 0x80, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, + 3, 0x80, 0x80, 4, 0x80, 0x80); + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + // cannot reuse shuf_A0 (has 5) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU vA0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU vA1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU vA2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const VFromD A = BitCast(d, vA0 | vA1 | vA2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const VU shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const VU shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. 
+ const VU shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const VU vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const Repartition du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. + const VFromD shuf_A1 = + Dup128VecFromValues(du8, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, 2, 3, + 0x80, 0x80, 0x80, 0x80, 4, 5); + const VFromD shuf_A2 = + Dup128VecFromValues(du8, 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, + 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80); + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // 5..4..3. + const VU8 shuf_B1 = shuf_A2 + k3; // ..4..3.. + const VU8 shuf_B2 = shuf_A0 + k2; // .4..3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU8 shuf_C0 = shuf_B1 + k3; // ..7..6.. + const VU8 shuf_C1 = shuf_B2 + k3; // .7..6..5 + const VU8 shuf_C2 = shuf_B0 + k2; // 7..6..5. + const VU8 vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU8 vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU8 vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const VFromD A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const VFromD v1_321 = ShiftRightLanes<1>(d, v1); + const VFromD v0_32 = ShiftRightLanes<2>(d, v0); + const VFromD v21_v11 = OddEven(v2, v1_321); + const VFromD v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const VFromD B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. 
+ const VFromD v23_v13 = OddEven(v2, v1_321); + const VFromD v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const VFromD C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const VFromD A = InterleaveLower(d, v0, v1); + const VFromD B = OddEven(v0, v2); + const VFromD C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and first result. + constexpr size_t kFullN = 16 / sizeof(TFromD); + const Full128 du; + using VU = VFromD; + const Full128> d_full; + const VU k5 = Set(du, uint8_t{5}); + const VU k6 = Set(du, uint8_t{6}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * kFullN); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // ..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // .7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // 7..6..5. + const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B{BitCast(d_full, vB0 | vB1 | vB2).raw}; + StoreU(B, d, unaligned + 1 * kFullN); +} + +// 64-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D dh, + TFromD* HWY_RESTRICT unaligned) { + const Twice d_full; + const Full128 du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. 
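// Usage sketch, not part of the vendored sources: StoreInterleaved2 (defined
// earlier in this section) writes two planar vectors as interleaved pairs,
// e.g. packing separate real/imaginary arrays into re0,im0,re1,im1,...
// The helper name and the single-full-vector precondition are assumptions.
#include <stddef.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

void PackComplexOnce(const float* HWY_RESTRICT re,
                     const float* HWY_RESTRICT im, float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  const auto vre = hn::LoadU(d, re);
  const auto vim = hn::LoadU(d, im);
  hn::StoreInterleaved2(vre, vim, d, out);  // writes 2 * Lanes(d) floats
}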
+ alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // ..3. + const VU8 shuf_B1 = shuf_A2 + k3; // .3.. + const VU8 shuf_B2 = shuf_A0 + k2; // 3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d_full, vB0 | vB1 | vB2); + StoreU(VFromD{B.raw}, dh, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + const VFromD v21_v11 = InterleaveUpper(d, v1, v2); + constexpr size_t kN = MaxLanes(d); + StoreU(v10_v00, d, unaligned + 0 * kN); + StoreU(v01_v20, d, unaligned + 1 * kN); + StoreU(v21_v11, d, unaligned + 2 * kN); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du; + using VU = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const VU shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const VFromD A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. 
+ const Full128 du8; + using VU8 = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + const VU8 shuf_A1 = + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); // ...1..0. + const VU8 shuf_A0 = + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); // ....1..0 + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks4(VFromD vA, VFromD vB, VFromD vC, + VFromD vD, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(vA, d, unaligned + 0 * kN); + StoreU(vB, d, unaligned + 1 * kN); + StoreU(vC, d, unaligned + 2 * kN); + StoreU(vD, d, unaligned + 3 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are vA, vB, vC, vD. + const VFromD vA = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const VFromD vB = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const VFromD vC = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const VFromD vD = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + // The interleaved vectors are vA, vB, vC, vD. + const VFromD vA = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const VFromD vB = InterleaveLower(d, v2, v3); + const VFromD vC = InterleaveUpper(d, v0, v1); + const VFromD vD = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. 
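// Usage sketch, not part of the vendored sources: StoreInterleaved4 packs
// four planar vectors into 4-tuples, e.g. R/G/B/A planes into packed RGBA.
// The helper name is hypothetical; each input plane is assumed to provide at
// least Lanes(d) bytes, and `rgba` space for 4 * Lanes(d) bytes.
#include <stddef.h>
#include <stdint.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

void PackRgbaOnce(const uint8_t* HWY_RESTRICT r, const uint8_t* HWY_RESTRICT g,
                  const uint8_t* HWY_RESTRICT b, const uint8_t* HWY_RESTRICT a,
                  uint8_t* HWY_RESTRICT rgba) {
  const hn::ScalableTag<uint8_t> d;
  const auto vr = hn::LoadU(d, r);
  const auto vg = hn::LoadU(d, g);
  const auto vb = hn::LoadU(d, b);
  const auto va = hn::LoadU(d, a);
  hn::StoreInterleaved4(vr, vg, vb, va, d, rgba);  // writes 4 * Lanes(d) bytes
}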
+ const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 64-bit lane +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// <= 32-bit vectors +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(v3210, d_full, buf); + CopyBytes(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// Load/StoreInterleaved for special floats. Requires HWY_GENERIC_IF_EMULATED_D +// is defined such that it is true only for types that actually require these +// generic implementations. 
+#if HWY_IDE || (defined(HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED) == \ + defined(HWY_TARGET_TOGGLE) && \ + defined(HWY_GENERIC_IF_EMULATED_D)) +#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif +#if HWY_IDE +#define HWY_GENERIC_IF_EMULATED_D(D) int +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const RebindToUnsigned du; + VFromD vu0, vu1; + LoadInterleaved2(du, detail::U16LanePointer(unaligned), vu0, vu1); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2; + LoadInterleaved3(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2, vu3; + LoadInterleaved4(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2, vu3); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); + v3 = BitCast(d, vu3); +} + +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved2(BitCast(du, v0), BitCast(du, v1), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved3(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved4(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), + BitCast(du, v3), du, detail::U16LanePointer(unaligned)); +} + +#endif // HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED + +// ------------------------------ LoadN + +#if (defined(HWY_NATIVE_LOAD_N) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD LoadNResizeBitCast(DTo d_to, DFrom d_from, + VFromD v) { +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4, the LoadU operation will zero out any lanes of v.raw // past the first (lowest-index) Lanes(d_from) lanes of v.raw if // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true (void)d_from; return ResizeBitCast(d_to, v); #else - // On other targets such as PPC/NEON, the contents of any lanes past the first - // (lowest-index) Lanes(d_from) lanes of v.raw might be non-zero if - // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true. - return ZeroExtendResizeBitCast(d_to, d_from, v); + // On other targets such as PPC/NEON, the contents of any lanes past the first + // (lowest-index) Lanes(d_from) lanes of v.raw might be non-zero if + // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true. + return ZeroExtendResizeBitCast(d_to, d_from, v); +#endif +} + +} // namespace detail + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? 
LoadU(d, p) : Zero(d); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? LoadU(d, p) : no; +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return no; + return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + const Half d1; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + // Two or three lanes. + const VFromD v_lo = detail::LoadNResizeBitCast(d, d2, LoadU(d2, p)); + return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + // Two or three lanes. + const VFromD v_lo = + ConcatUpperLower(d, no, ResizeBitCast(d, LoadU(d2, p))); + return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return 
ConcatUpperLower(d, no, ResizeBitCast(d, v_trailing)); + } +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? Combine(d8, v_trailing, LoadU(d4, p + 8)) + : detail::LoadNResizeBitCast(d8, d4, v_trailing); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return detail::LoadNResizeBitCast(d, d8, + Combine(d8, v_trailing, LoadU(d4, p))); + } + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? Combine(d8, v_trailing, LoadU(d4, p + 8)) + : ConcatUpperLower(d8, ResizeBitCast(d8, no), + ResizeBitCast(d8, v_trailing)); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return ConcatUpperLower( + d, ResizeBitCast(d, no), + ResizeBitCast(d, Combine(d8, v_trailing, LoadU(d4, p)))); + } + } else { + const Repartition du32; + // lowest 4 bytes from v_trailing, next 4 from no. 
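The LoadN/LoadNOr overloads in this hunk, together with StoreN further down, let a loop finish its partial final vector without reading or writing past the end of the buffer. A hedged sketch (`Scale` is a hypothetical helper):

```cpp
#include <cstddef>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void Scale(float* HWY_RESTRICT data, size_t n, float factor) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto vf = hn::Set(d, factor);
  size_t i = 0;
  for (; i + N <= n; i += N) {
    hn::StoreU(hn::Mul(hn::LoadU(d, data + i), vf), d, data + i);
  }
  if (const size_t remaining = n - i) {
    // Touches only `remaining` lanes; LoadN zero-fills the rest of the vector.
    const auto v = hn::LoadN(d, data + i, remaining);
    hn::StoreN(hn::Mul(v, vf), d, data + i, remaining);
  }
}
```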
+ const VFromD lo8 = + InterleaveLower(ResizeBitCast(du32, v_trailing), BitCast(du32, no)); + return ConcatUpperLower(d, ResizeBitCast(d, no), ResizeBitCast(d, lo8)); + } +} + +#if HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + if (num_lanes <= half_N) { + return ZeroExtendVector(d, LoadN(dh, p, num_lanes)); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = LoadN(dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + const VFromD no_h = LowerHalf(no); + if (num_lanes <= half_N) { + return ConcatUpperLower(d, no, + ResizeBitCast(d, LoadNOr(no_h, dh, p, num_lanes))); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = + LoadNOr(no_h, dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +#endif // HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, no), du, detail::U16LanePointer(p), num_lanes)); +} + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE + +// For SVE and non-sanitizer AVX-512; RVV has its own specialization. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return Zero(d); +#endif + + return MaskedLoad(FirstN(d, num_lanes), d, p); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return no; +#endif + + return MaskedLoadOr(no, FirstN(d, num_lanes), d, p); +} + +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +#endif // HWY_NATIVE_LOAD_N + +// ------------------------------ StoreN +#if (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + constexpr size_t kMinShrVectBytes = HWY_TARGET_IS_NEON ? 
8 : 16; + const FixedTag d_shift; + return ResizeBitCast( + dh, ShiftRightBytes(d_shift, ResizeBitCast(d_shift, v))); +} + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + return UpperHalf(dh, v); +} + +} // namespace detail + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + StoreU(v, d, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 1) { + StoreU(v, d, p); + } else if (max_lanes_to_store == 1) { + const FixedTag, 1> d1; + StoreU(LowerHalf(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 2> d2; + const Half d1; + + if (max_lanes_to_store > 1) { + if (max_lanes_to_store >= 4) { + StoreU(v, d, p); + } else { + StoreU(ResizeBitCast(d2, v), d2, p); + if (max_lanes_to_store == 3) { + StoreU(ResizeBitCast(d1, detail::StoreNGetUpperHalf(d2, v)), d1, p + 2); + } + } + } else if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 8) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 4) { + StoreU(LowerHalf(d4, v), d4, p); + StoreN(detail::StoreNGetUpperHalf(d4, v), d4, p + 4, + max_lanes_to_store - 4); + } else { + StoreN(LowerHalf(d4, v), d4, p, max_lanes_to_store); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 16) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 8) { + StoreU(LowerHalf(d8, v), d8, p); + StoreN(detail::StoreNGetUpperHalf(d8, v), d8, p + 8, + max_lanes_to_store - 8); + } else { + StoreN(LowerHalf(d8, v), d8, p, max_lanes_to_store); + } +} + +#if HWY_MAX_BYTES >= 32 +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + if (max_lanes_to_store >= N) { + StoreU(v, d, p); + return; + } + + const Half dh; + const size_t half_N = Lanes(dh); + if (max_lanes_to_store <= half_N) { + StoreN(LowerHalf(dh, v), dh, p, max_lanes_to_store); + } else { + StoreU(LowerHalf(dh, v), dh, p); + StoreN(UpperHalf(dh, v), dh, p + half_N, max_lanes_to_store - half_N); + } +} +#endif // HWY_MAX_BYTES >= 32 + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + const size_t clamped_max_lanes_to_store = HWY_MIN(max_lanes_to_store, N); +#if HWY_MEM_OPS_MIGHT_FAULT + if (clamped_max_lanes_to_store == 0) return; +#endif + + BlendedStore(v, FirstN(d, clamped_max_lanes_to_store), d, p); + + detail::MaybeUnpoison(p, clamped_max_lanes_to_store); +} +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE + +#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Scatter + +#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SCATTER +#undef 
HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +template > +HWY_API void ScatterOffset(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template > +HWY_API void ScatterIndex(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, + T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask_lanes[i]) base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (i < max_lanes_to_store) base[ExtractLane(index, i)] = ExtractLane(v, i); + } +} +#else +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + MaskedScatterIndex(v, FirstN(d, max_lanes_to_store), d, base, index); +} +#endif // (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Gather + +#if (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +template > +HWY_API VFromD GatherOffset(D d, const T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(offset_lanes[i] >= 0); + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template > +HWY_API VFromD GatherIndex(D d, const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, 
+ const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : T{0}; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T no_lanes[MaxLanes(d)]; + Store(no, d, no_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : no_lanes[i]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + VFromD v = Zero(d); + for (size_t i = 0; i < HWY_MIN(MaxLanes(d), max_lanes_to_load); ++i) { + v = InsertLane(v, i, base[ExtractLane(index, i)]); + } + return v; +} + +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + VFromD v = no; + for (size_t i = 0; i < HWY_MIN(MaxLanes(d), max_lanes_to_load); ++i) { + v = InsertLane(v, i, base[ExtractLane(index, i)]); + } + return v; +} +#else +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndex(FirstN(d, max_lanes_to_load), d, base, index); +} +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndexOr(no, FirstN(d, max_lanes_to_load), d, base, index); +} +#endif // (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Integer AbsDiff and SumsOf8AbsDiff + +#if (defined(HWY_NATIVE_INTEGER_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API V AbsDiff(V a, V b) { + return Sub(Max(a, b), Min(a, b)); +} + +#endif // HWY_NATIVE_INTEGER_ABS_DIFF + +#if (defined(HWY_NATIVE_SUMS_OF_8_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#endif + +template ), + HWY_IF_V_SIZE_GT_D(DFromV, (HWY_TARGET == HWY_SCALAR ? 
0 : 4))> +HWY_API Vec>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 dw; + + return BitCast(dw, SumsOf8(BitCast(du, AbsDiff(a, b)))); +} + +#endif // HWY_NATIVE_SUMS_OF_8_ABS_DIFF + +// ------------------------------ SaturatedAdd/SaturatedSub for UI32/UI64 + +#if (defined(HWY_NATIVE_I32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_NATIVE_I32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_I64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_NATIVE_I64_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); +} + +#endif // HWY_NATIVE_U32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB #endif -} -} // namespace detail +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - return (num_lanes > 0) ? LoadU(d, p) : Zero(d); +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); } -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - return (num_lanes > 0) ? 
LoadU(d, p) : no; +#endif // HWY_NATIVE_U64_SATURATED_ADDSUB + +// ------------------------------ Unsigned to signed demotions + +template , DN>>, + hwy::EnableIf<(sizeof(TFromD) < sizeof(TFromV))>* = nullptr, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD DemoteTo(DN dn, V v) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; + + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. + const auto i2i_demote_result = DemoteTo(dn, BitCast(di, v)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. + const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); } -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 1> d1; +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template , DN>>, + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; - if (num_lanes >= 2) return LoadU(d, p); - if (num_lanes == 0) return Zero(d); - return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. + const auto i2i_demote_result = + ReorderDemote2To(dn, BitCast(di, a), BitCast(di, b)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. + const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); } +#endif -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 1> d1; +// ------------------------------ PromoteLowerTo - if (num_lanes >= 2) return LoadU(d, p); - if (num_lanes == 0) return no; - return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); +// There is no codegen advantage for a native version of this. It is provided +// only for convenience. +template +HWY_API VFromD PromoteLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteTo(d, LowerHalf(dh, v)); } -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 2> d2; - const Half d1; +// ------------------------------ PromoteUpperTo - if (num_lanes >= 4) return LoadU(d, p); - if (num_lanes == 0) return Zero(d); - if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); +#if (defined(HWY_NATIVE_PROMOTE_UPPER_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif - // Two or three lanes. - const VFromD v_lo = detail::LoadNResizeBitCast(d, d2, LoadU(d2, p)); - return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +// This requires UpperHalf. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
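The signed SaturatedAdd/SaturatedSub overloads earlier in this hunk detect overflow with `AndNot(Xor(a, b), Xor(a, sum))`: overflow is only possible when both operands share a sign and the wrapped sum's sign differs, and the saturated result then takes the sign of `a`. A scalar sketch of the same test (illustrative, not from the patch):

```cpp
#include <cstdint>
#include <limits>

// Overflow iff a and b have the same sign but the wrapped sum has a different
// sign; the saturated result is INT32_MAX for positive a, INT32_MIN otherwise.
static inline int32_t SaturatedAddI32(int32_t a, int32_t b) {
  const uint32_t ua = static_cast<uint32_t>(a);
  const uint32_t ub = static_cast<uint32_t>(b);
  const uint32_t sum = ua + ub;  // wraps, unlike signed overflow (which is UB)
  const bool overflow = ((~(ua ^ ub)) & (ua ^ sum)) >> 31;
  if (!overflow) return static_cast<int32_t>(sum);
  return (a < 0) ? std::numeric_limits<int32_t>::min()
                 : std::numeric_limits<int32_t>::max();
}
```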
+ const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); } -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 2> d2; +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_PROMOTE_UPPER_TO - if (num_lanes >= 4) return LoadU(d, p); - if (num_lanes == 0) return no; - if (num_lanes == 1) return InsertLane(no, 0, p[0]); +// ------------------------------ float16_t <-> float - // Two or three lanes. - const VFromD v_lo = - ConcatUpperLower(d, no, ResizeBitCast(d, LoadU(d2, p))); - return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +#if (defined(HWY_NATIVE_F16C) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + const Rebind du16; + using VU32 = VFromD; + + const VU32 bits16 = PromoteTo(du32, BitCast(du16, v)); + const VU32 sign = ShiftRight<15>(bits16); + const VU32 biased_exp = And(ShiftRight<10>(bits16), Set(du32, 0x1F)); + const VU32 mantissa = And(bits16, Set(du32, 0x3FF)); + const VU32 subnormal = + BitCast(du32, Mul(ConvertTo(df32, BitCast(di32, mantissa)), + Set(df32, 1.0f / 16384 / 1024))); + + const VU32 biased_exp32 = Add(biased_exp, Set(du32, 127 - 15)); + const VU32 mantissa32 = ShiftLeft<23 - 10>(mantissa); + const VU32 normal = Or(ShiftLeft<23>(biased_exp32), mantissa32); + const VU32 bits32 = IfThenElse(Eq(biased_exp, Zero(du32)), subnormal, normal); + return BitCast(df32, Or(ShiftLeft<31>(sign), bits32)); } -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 4> d4; - const Half d2; - const Half d1; +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToSigned di16; + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; - if (num_lanes >= 8) return LoadU(d, p); - if (num_lanes == 0) return Zero(d); - if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + // There are 23 fractional bits (plus the implied 1 bit) in the mantissa of + // a F32, and there are 10 fractional bits (plus the implied 1 bit) in the + // mantissa of a F16 - const size_t leading_len = num_lanes & 4; - VFromD v_trailing = Zero(d4); + // We want the unbiased exponent of round_incr[i] to be at least (-14) + 13 as + // 2^(-14) is the smallest positive normal F16 value and as we want 13 + // mantissa bits (including the implicit 1 bit) to the left of the + // F32 mantissa bits in rounded_val[i] since 23 - 10 is equal to 13 - if ((num_lanes & 2) != 0) { - const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); - if ((num_lanes & 1) != 0) { - v_trailing = Combine( - d4, - detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), - v_trailing_lo2); - } else { - v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); - } - } else if ((num_lanes & 1) != 0) { - v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); - } + // The biased exponent of round_incr[i] needs to be at least 126 as + // (-14) + 13 + 127 is equal to 126 - if (leading_len != 0) { - return Combine(d, v_trailing, LoadU(d4, p)); - } else { - return detail::LoadNResizeBitCast(d, d4, v_trailing); - } -} + // We also want to biased exponent of round_incr[i] to be less than or equal + // to 255 (which is equal to MaxExponentField()) -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - 
size_t num_lanes) { - const FixedTag, 4> d4; - const Half d2; - const Half d1; + // The biased F32 exponent of round_incr is equal to + // HWY_MAX(HWY_MIN(((exp_bits[i] >> 23) & 255) + 13, 255), 126) - if (num_lanes >= 8) return LoadU(d, p); - if (num_lanes == 0) return no; - if (num_lanes == 1) return InsertLane(no, 0, p[0]); + // hi9_bits[i] is equal to the upper 9 bits of v[i] + const auto hi9_bits = ShiftRight<23>(BitCast(du32, v)); - const size_t leading_len = num_lanes & 4; - VFromD v_trailing = ResizeBitCast(d4, no); + const auto k13 = Set(du32, uint32_t{13u}); - if ((num_lanes & 2) != 0) { - const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); - if ((num_lanes & 1) != 0) { - v_trailing = Combine( - d4, - InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), - ResizeBitCast(d2, no)), - v_trailing_lo2); - } else { - v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), - ResizeBitCast(d4, v_trailing_lo2)); - } - } else if ((num_lanes & 1) != 0) { - v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); - } + // Minimum biased F32 exponent of round_incr + const auto k126 = Set(du32, uint32_t{126u}); - if (leading_len != 0) { - return Combine(d, v_trailing, LoadU(d4, p)); - } else { - return ConcatUpperLower(d, no, ResizeBitCast(d, v_trailing)); - } + // round_incr_hi9_bits[i] is equivalent to + // (hi9_bits[i] & 0x100) | + // HWY_MAX(HWY_MIN((hi9_bits[i] & 0xFF) + 13, 255), 126) + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k255 = Set(du32, uint32_t{255u}); + const auto round_incr_hi9_bits = BitwiseIfThenElse( + k255, Max(Min(Add(And(hi9_bits, k255), k13), k255), k126), hi9_bits); +#else + // On targets other than SCALAR and EMU128, the exponent bits of hi9_bits can + // be incremented by 13 and clamped to the [13, 255] range without overflowing + // into the sign bit of hi9_bits by using U8 SaturatedAdd as there are 8 + // exponent bits in an F32 + + // U8 Max can be used on targets other than SCALAR and EMU128 to clamp + // ((hi9_bits & 0xFF) + 13) to the [126, 255] range without affecting the sign + // bit + + const Repartition du32_as_u8; + const auto round_incr_hi9_bits = BitCast( + du32, + Max(SaturatedAdd(BitCast(du32_as_u8, hi9_bits), BitCast(du32_as_u8, k13)), + BitCast(du32_as_u8, k126))); +#endif + + // (round_incr_hi9_bits >> 8) is equal to (hi9_bits >> 8), and + // (round_incr_hi9_bits & 0xFF) is equal to + // HWY_MAX(HWY_MIN((round_incr_hi9_bits & 0xFF) + 13, 255), 126) + + const auto round_incr = BitCast(df32, ShiftLeft<23>(round_incr_hi9_bits)); + + // Add round_incr[i] to v[i] to round the mantissa to the nearest F16 mantissa + // and to move the fractional bits of the resulting non-NaN mantissa down to + // the lower 10 bits of rounded_val if (v[i] + round_incr[i]) is a non-NaN + // value + const auto rounded_val = Add(v, round_incr); + + // rounded_val_bits is the bits of rounded_val as a U32 + const auto rounded_val_bits = BitCast(du32, rounded_val); + + // rounded_val[i] is known to have the same biased exponent as round_incr[i] + // as |round_incr[i]| > 2^12*|v[i]| is true if round_incr[i] is a finite + // value, round_incr[i] and v[i] both have the same sign, and |round_incr[i]| + // is either a power of 2 that is greater than or equal to 2^-1 or infinity. + + // If rounded_val[i] is a finite F32 value, then + // (rounded_val_bits[i] & 0x00000FFF) is the bit representation of the + // rounded mantissa of rounded_val[i] as a UQ2.10 fixed point number that is + // in the range [0, 2]. 
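As a reading aid (not part of the vendored source), a scalar sketch of the rounding increment described in the comments above and below: it has the same sign as v, a zero mantissa, and a biased exponent of clamp(exp(v) + 13, 126, 255).

```cpp
#include <cstdint>
#include <cstring>

// Builds the per-lane rounding increment: same sign as v, zero mantissa,
// biased exponent clamp(exp(v) + 13, 126, 255).
static inline float RoundIncrForF16(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  const uint32_t hi9 = bits >> 23;     // sign and biased exponent of v
  uint32_t exp = (hi9 & 0xFFu) + 13u;  // 13 bits above v's mantissa LSB
  exp = exp > 255u ? 255u : exp;       // saturate at the Inf/NaN exponent
  exp = exp < 126u ? 126u : exp;       // at least 2^-1 (F16 subnormal range)
  const uint32_t incr_bits = ((hi9 & 0x100u) | exp) << 23;
  float incr;
  std::memcpy(&incr, &incr_bits, sizeof(incr));
  return incr;
}
```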
+ + // In other words, (rounded_val_bits[i] & 0x00000FFF) is between 0 and 0x0800, + // with (rounded_val_bits[i] & 0x000003FF) being the fractional bits of the + // resulting F16 mantissa, if rounded_v[i] is a finite F32 value. + + // (rounded_val_bits[i] & 0x007FF000) == 0 is guaranteed to be true if + // rounded_val[i] is a non-NaN value + + // The biased exponent of rounded_val[i] is guaranteed to be at least 126 as + // the biased exponent of round_incr[i] is at least 126 and as both v[i] and + // round_incr[i] have the same sign bit + + // The ULP of a F32 value with a biased exponent of 126 is equal to + // 2^(126 - 127 - 23), which is equal to 2^(-24) (which is also the ULP of a + // F16 value with a biased exponent of 0 or 1 as (1 - 15 - 10) is equal to + // -24) + + // The biased exponent (before subtracting by 126) needs to be clamped to the + // [126, 157] range as 126 + 31 is equal to 157 and as 31 is the largest + // biased exponent of a F16. + + // The biased exponent of the resulting F16 value is equal to + // HWY_MIN((round_incr_hi9_bits[i] & 0xFF) + + // ((rounded_val_bits[i] >> 10) & 0xFF), 157) - 126 + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k157Shl10 = Set(du32, static_cast(uint32_t{157u} << 10)); + auto f16_exp_bits = + Min(Add(ShiftLeft<10>(And(round_incr_hi9_bits, k255)), + And(rounded_val_bits, + Set(du32, static_cast(uint32_t{0xFFu} << 10)))), + k157Shl10); + const auto f16_result_is_inf_mask = + RebindMask(df32, Eq(f16_exp_bits, k157Shl10)); +#else + const auto k157 = Set(du32, uint32_t{157}); + auto f16_exp_bits = BitCast( + du32, + Min(SaturatedAdd(BitCast(du32_as_u8, round_incr_hi9_bits), + BitCast(du32_as_u8, ShiftRight<10>(rounded_val_bits))), + BitCast(du32_as_u8, k157))); + const auto f16_result_is_inf_mask = RebindMask(df32, Eq(f16_exp_bits, k157)); + f16_exp_bits = ShiftLeft<10>(f16_exp_bits); +#endif + + f16_exp_bits = + Sub(f16_exp_bits, Set(du32, static_cast(uint32_t{126u} << 10))); + + const auto f16_unmasked_mant_bits = + BitCast(di32, Or(IfThenZeroElse(f16_result_is_inf_mask, rounded_val), + VecFromMask(df32, IsNaN(rounded_val)))); + + const auto f16_exp_mant_bits = + OrAnd(BitCast(di32, f16_exp_bits), f16_unmasked_mant_bits, + Set(di32, int32_t{0x03FF})); + + // f16_bits_as_i32 is the F16 bits sign-extended to an I32 (with the upper 17 + // bits of f16_bits_as_i32[i] set to the sign bit of rounded_val[i]) to allow + // efficient truncation of the F16 bits to an I16 using an I32->I16 DemoteTo + // operation + const auto f16_bits_as_i32 = + OrAnd(f16_exp_mant_bits, ShiftRight<16>(BitCast(di32, rounded_val_bits)), + Set(di32, static_cast(0xFFFF8000u))); + return BitCast(df16, DemoteTo(di16, f16_bits_as_i32)); } -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 8> d8; - const Half d4; - const Half d2; - const Half d1; +#endif // HWY_NATIVE_F16C - if (num_lanes >= 16) return LoadU(d, p); - if (num_lanes == 0) return Zero(d); - if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); +// ------------------------------ F64->F16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F64_TO_F16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif - const size_t leading_len = num_lanes & 12; - VFromD v_trailing = Zero(d4); +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df64; + const Rebind du64; + const Rebind df32; 
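A hedged usage sketch of the f32/f16 conversions defined in this hunk (`ToF16` and its buffer layout are assumptions; n is taken to be a multiple of the vector length):

```cpp
#include <cstddef>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Converts float data to IEEE binary16 storage.
void ToF16(const float* in, size_t n, hwy::float16_t* out) {
  const hn::ScalableTag<float> df32;
  const hn::Rebind<hwy::float16_t, decltype(df32)> df16;  // same lane count
  const size_t N = hn::Lanes(df32);
  for (size_t i = 0; i < n; i += N) {
    hn::StoreU(hn::DemoteTo(df16, hn::LoadU(df32, in + i)), df16, out + i);
  }
}
```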
- if ((num_lanes & 2) != 0) { - const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); - if ((num_lanes & 1) != 0) { - v_trailing = Combine( - d4, - detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), - v_trailing_lo2); - } else { - v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); - } - } else if ((num_lanes & 1) != 0) { - v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); - } + // The mantissa bits of v[i] are first rounded using round-to-odd rounding to + // the nearest F64 value that has the lower 29 bits zeroed out to ensure that + // the result is correctly rounded to a F16. + + const auto vf64_rounded = OrAnd( + And(v, + BitCast(df64, Set(du64, static_cast(0xFFFFFFFFE0000000u)))), + BitCast(df64, Add(BitCast(du64, v), + Set(du64, static_cast(0x000000001FFFFFFFu)))), + BitCast(df64, Set(du64, static_cast(0x0000000020000000ULL)))); + + return DemoteTo(df16, DemoteTo(df32, vf64_rounded)); +} +#endif // HWY_HAVE_FLOAT64 - if (leading_len != 0) { - if (leading_len >= 8) { - const VFromD v_hi7 = - ((leading_len & 4) != 0) - ? Combine(d8, v_trailing, LoadU(d4, p + 8)) - : detail::LoadNResizeBitCast(d8, d4, v_trailing); - return Combine(d, v_hi7, LoadU(d8, p)); - } else { - return detail::LoadNResizeBitCast(d, d8, - Combine(d8, v_trailing, LoadU(d4, p))); - } - } else { - return detail::LoadNResizeBitCast(d, d4, v_trailing); - } +#endif // HWY_NATIVE_DEMOTE_F64_TO_F16 + +// ------------------------------ F16->F64 PromoteTo +#if (defined(HWY_NATIVE_PROMOTE_F16_TO_F64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + return PromoteTo(df64, PromoteTo(Rebind(), v)); } +#endif // HWY_HAVE_FLOAT64 -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - const FixedTag, 8> d8; - const Half d4; - const Half d2; - const Half d1; +#endif // HWY_NATIVE_PROMOTE_F16_TO_F64 - if (num_lanes >= 16) return LoadU(d, p); - if (num_lanes == 0) return no; - if (num_lanes == 1) return InsertLane(no, 0, p[0]); +// ------------------------------ F32 to BF16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F32_TO_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif - const size_t leading_len = num_lanes & 12; - VFromD v_trailing = ResizeBitCast(d4, no); +namespace detail { - if ((num_lanes & 2) != 0) { - const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); - if ((num_lanes & 1) != 0) { - v_trailing = Combine( - d4, - InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), - ResizeBitCast(d2, no)), - v_trailing_lo2); - } else { - v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), - ResizeBitCast(d4, v_trailing_lo2)); - } - } else if ((num_lanes & 1) != 0) { - v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); - } +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 - if (leading_len != 0) { - if (leading_len >= 8) { - const VFromD v_hi7 = - ((leading_len & 4) != 0) - ? 
Combine(d8, v_trailing, LoadU(d4, p + 8)) - : ConcatUpperLower(d8, ResizeBitCast(d8, no), - ResizeBitCast(d8, v_trailing)); - return Combine(d, v_hi7, LoadU(d8, p)); - } else { - return ConcatUpperLower( - d, ResizeBitCast(d, no), - ResizeBitCast(d, Combine(d8, v_trailing, LoadU(d4, p)))); - } - } else { - const Repartition du32; - // lowest 4 bytes from v_trailing, next 4 from no. - const VFromD lo8 = - InterleaveLower(ResizeBitCast(du32, v_trailing), BitCast(du32, no)); - return ConcatUpperLower(d, ResizeBitCast(d, no), ResizeBitCast(d, lo8)); - } +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +template )> +HWY_INLINE VFromD>> RoundF32ForDemoteToBF16(V v) { + const DFromV d; + const RebindToUnsigned du32; + + const auto is_non_nan = Not(IsNaN(v)); + const auto bits32 = BitCast(du32, v); + + const auto round_incr = + Add(And(ShiftRight<16>(bits32), Set(du32, uint32_t{1})), + Set(du32, uint32_t{0x7FFFu})); + return MaskedAddOr(Or(bits32, Set(du32, uint32_t{0x00400000u})), + RebindMask(du32, is_non_nan), bits32, round_incr); } -#if HWY_MAX_BYTES >= 32 +} // namespace detail -template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - if (num_lanes >= Lanes(d)) return LoadU(d, p); +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + const Twice dt_u16; - const Half dh; - const size_t half_N = Lanes(dh); - if (num_lanes <= half_N) { - return ZeroExtendVector(d, LoadN(dh, p, num_lanes)); - } else { - const VFromD v_lo = LoadU(dh, p); - const VFromD v_hi = LoadN(dh, p + half_N, num_lanes - half_N); - return Combine(d, v_hi, v_lo); - } + const auto rounded_bits = BitCast(dt_u16, detail::RoundF32ForDemoteToBF16(v)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast( + dbf16, LowerHalf(du16, ConcatOdd(dt_u16, rounded_bits, rounded_bits))); +#else + return BitCast( + dbf16, LowerHalf(du16, ConcatEven(dt_u16, rounded_bits, rounded_bits))); +#endif } -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { - if (num_lanes >= Lanes(d)) return LoadU(d, p); +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; - const Half dh; - const size_t half_N = Lanes(dh); - const VFromD no_h = LowerHalf(no); - if (num_lanes <= half_N) { - return ConcatUpperLower(d, no, - ResizeBitCast(d, LoadNOr(no_h, dh, p, num_lanes))); - } else { - const VFromD v_lo = LoadU(dh, p); - const VFromD v_hi = - LoadNOr(no_h, dh, p + half_N, num_lanes - half_N); - return Combine(d, v_hi, v_lo); - } + const auto rounded_a_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const auto rounded_b_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#else + return BitCast(dbf16, ConcatEven(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#endif } -#endif // HWY_MAX_BYTES >= 32 -#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; -// For SVE and non-sanitizer AVX-512; RVV has its own specialization. 
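For reference, a scalar sketch of the rounding performed by RoundF32ForDemoteToBF16 above (illustrative only): add 0x7FFF plus the lowest kept bit for round-to-nearest-even, quiet NaNs instead of rounding them, then keep the upper 16 bits.

```cpp
#include <cstdint>
#include <cstring>

static inline uint16_t F32ToBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const bool is_nan = (bits & 0x7FFFFFFFu) > 0x7F800000u;
  if (is_nan) {
    bits |= 0x00400000u;  // force a quiet NaN instead of rounding the payload
  } else {
    bits += 0x7FFFu + ((bits >> 16) & 1u);  // round to nearest, ties to even
  }
  return static_cast<uint16_t>(bits >> 16);
}
```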
-template -HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { -#if HWY_MEM_OPS_MIGHT_FAULT - if (num_lanes <= 0) return Zero(d); +#if HWY_IS_LITTLE_ENDIAN + const auto a_in_odd = detail::RoundF32ForDemoteToBF16(a); + const auto b_in_even = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(b)); +#else + const auto a_in_odd = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_even = detail::RoundF32ForDemoteToBF16(b); #endif - return MaskedLoad(FirstN(d, num_lanes), d, p); + return BitCast(dbf16, + OddEven(BitCast(du16, a_in_odd), BitCast(du16, b_in_even))); } -template -HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, - size_t num_lanes) { -#if HWY_MEM_OPS_MIGHT_FAULT - if (num_lanes <= 0) return no; +#endif // HWY_NATIVE_DEMOTE_F32_TO_BF16 + +// ------------------------------ PromoteInRangeTo +#if (defined(HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO #endif - return MaskedLoadOr(no, FirstN(d, num_lanes), d, p); +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + return PromoteTo(d64, v); } +#endif -#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE -#endif // HWY_NATIVE_LOAD_N +#endif // HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO -// ------------------------------ StoreN -#if (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_STORE_N -#undef HWY_NATIVE_STORE_N +// ------------------------------ ConvertInRangeTo +#if (defined(HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO #else -#define HWY_NATIVE_STORE_N +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO #endif -#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE -namespace detail { +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + return ConvertTo(di, v); +} -template -HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { - constexpr size_t kMinShrVectBytes = - (HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_NEON_WITHOUT_AES) ? 8 : 16; - const FixedTag d_shift; - return ResizeBitCast( - dh, ShiftRightBytes(d_shift, ResizeBitCast(d_shift, v))); +#endif // HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO + +// ------------------------------ DemoteInRangeTo +#if (defined(HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteInRangeTo(D32 d32, VFromD> v) { + return DemoteTo(d32, v); } +#endif -template -HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { - return UpperHalf(dh, v); +#endif // HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO + +// ------------------------------ PromoteInRangeLowerTo/PromoteInRangeUpperTo + +template )> +HWY_API VFromD PromoteInRangeLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
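A hedged sketch of the *InRange* conversions declared here: they assume every lane is already representable in the destination type, so out-of-range lanes are unspecified and targets may skip clamping. `TruncateToInt` and the in-range guarantee are assumptions of the sketch; the partial tail is omitted.

```cpp
#include <cstddef>
#include <cstdint>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void TruncateToInt(const float* in, int32_t* out, size_t n) {
  const hn::ScalableTag<float> df;
  const hn::RebindToSigned<decltype(df)> di;
  const size_t N = hn::Lanes(df);
  for (size_t i = 0; i + N <= n; i += N) {
    // Caller guarantees |in[j]| < 2^31, so ConvertInRangeTo is well defined.
    hn::StoreU(hn::ConvertInRangeTo(di, hn::LoadU(df, in + i)), di, out + i);
  }
}
```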
+ const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, v)); } -} // namespace detail +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeUpperTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the upper half of v using PromoteInRangeTo -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - if (max_lanes_to_store > 0) { - StoreU(v, d, p); - } + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, UpperHalf(dh, v)); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the upper half of v to TFromD using + // PromoteUpperTo + return PromoteUpperTo(d, v); +#endif } +#endif // HWY_TARGET != HWY_SCALAR -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - if (max_lanes_to_store > 1) { - StoreU(v, d, p); - } else if (max_lanes_to_store == 1) { - const FixedTag, 1> d1; - StoreU(LowerHalf(d1, v), d1, p); - } +// ------------------------------ PromoteInRangeEvenTo/PromoteInRangeOddTo + +template )> +HWY_API VFromD PromoteInRangeEvenTo(D d, V v) { +#if HWY_TARGET == HWY_SCALAR + return PromoteInRangeTo(d, v); +#elif (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the even lanes of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatEven(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the even lanes of v to TFromD using + // PromoteEvenTo + return PromoteEvenTo(d, v); +#endif // HWY_TARGET == HWY_SCALAR } -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const FixedTag, 2> d2; - const Half d1; +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeOddTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the odd lanes of v using PromoteInRangeTo - if (max_lanes_to_store > 1) { - if (max_lanes_to_store >= 4) { - StoreU(v, d, p); - } else { - StoreU(ResizeBitCast(d2, v), d2, p); - if (max_lanes_to_store == 3) { - StoreU(ResizeBitCast(d1, detail::StoreNGetUpperHalf(d2, v)), d1, p + 2); - } - } - } else if (max_lanes_to_store == 1) { - StoreU(ResizeBitCast(d1, v), d1, p); - } + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
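A hedged sketch of PromoteInRangeLowerTo/PromoteInRangeUpperTo (`ToInt64` is hypothetical; assumes a target with 64-bit integer lanes and more than one lane per vector, and that every input lane fits in int64_t):

```cpp
#include <cstdint>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Widens one f32 vector into two i64 vectors.
void ToInt64(const float* in, int64_t* out) {
  const hn::ScalableTag<float> df;
  const hn::RepartitionToWide<hn::RebindToSigned<decltype(df)>> di64;
  const auto v = hn::LoadU(df, in);
  hn::StoreU(hn::PromoteInRangeLowerTo(di64, v), di64, out);
  hn::StoreU(hn::PromoteInRangeUpperTo(di64, v), di64, out + hn::Lanes(di64));
}
```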
+ const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatOdd(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the odd lanes of v to TFromD using + // PromoteOddTo + return PromoteOddTo(d, v); +#endif } +#endif // HWY_TARGET != HWY_SCALAR -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const FixedTag, 4> d4; - const Half d2; - const Half d1; +// ------------------------------ SumsOf2 - if (max_lanes_to_store <= 1) { - if (max_lanes_to_store == 1) { - StoreU(ResizeBitCast(d1, v), d1, p); - } - } else if (max_lanes_to_store >= 8) { - StoreU(v, d, p); - } else if (max_lanes_to_store >= 4) { - StoreU(LowerHalf(d4, v), d4, p); - StoreN(detail::StoreNGetUpperHalf(d4, v), d4, p + 4, - max_lanes_to_store - 4); - } else { - StoreN(LowerHalf(d4, v), d4, p, max_lanes_to_store); - } +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return Add(PromoteEvenTo(dw, v), PromoteOddTo(dw, v)); } -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const FixedTag, 8> d8; - const Half d4; - const Half d2; - const Half d1; +} // namespace detail + +template +HWY_API VFromD>> SumsOf2(V v) { + return detail::SumsOf2(hwy::TypeTag>(), + hwy::SizeTag)>(), v); +} +#endif // HWY_TARGET != HWY_SCALAR - if (max_lanes_to_store <= 1) { - if (max_lanes_to_store == 1) { - StoreU(ResizeBitCast(d1, v), d1, p); - } - } else if (max_lanes_to_store >= 16) { - StoreU(v, d, p); - } else if (max_lanes_to_store >= 8) { - StoreU(LowerHalf(d8, v), d8, p); - StoreN(detail::StoreNGetUpperHalf(d8, v), d8, p + 8, - max_lanes_to_store - 8); - } else { - StoreN(LowerHalf(d8, v), d8, p, max_lanes_to_store); - } +// ------------------------------ SumsOf4 + +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf4( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + using hwy::HWY_NAMESPACE::SumsOf2; + return SumsOf2(SumsOf2(v)); } -#if HWY_MAX_BYTES >= 32 -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const size_t N = Lanes(d); - if (max_lanes_to_store >= N) { - StoreU(v, d, p); - return; - } +} // namespace detail - const Half dh; - const size_t half_N = Lanes(dh); - if (max_lanes_to_store <= half_N) { - StoreN(LowerHalf(dh, v), dh, p, max_lanes_to_store); - } else { - StoreU(LowerHalf(dh, v), dh, p); - StoreN(UpperHalf(dh, v), dh, p + half_N, max_lanes_to_store - half_N); - } +template +HWY_API VFromD>> SumsOf4(V v) { + return detail::SumsOf4(hwy::TypeTag>(), + hwy::SizeTag)>(), v); } -#endif // HWY_MAX_BYTES >= 32 -#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE -template > -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const size_t N = Lanes(d); - const size_t clamped_max_lanes_to_store = HWY_MIN(max_lanes_to_store, N); -#if HWY_MEM_OPS_MIGHT_FAULT - if (clamped_max_lanes_to_store == 0) return; -#endif +// ------------------------------ OrderedTruncate2To - BlendedStore(v, FirstN(d, clamped_max_lanes_to_store), d, p); +#if HWY_IDE || \ + (defined(HWY_NATIVE_ORDERED_TRUNCATE_2_TO) == defined(HWY_TARGET_TOGGLE)) -#if HWY_MEM_OPS_MIGHT_FAULT - detail::MaybeUnpoison(p, clamped_max_lanes_to_store); +#ifdef 
HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO #endif -} -#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE -#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template ) * 2), + HWY_IF_LANES_D(DFromV>, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_ORDERED_TRUNCATE_2_TO -// ------------------------------ Scatter +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex -#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_SCATTER -#undef HWY_NATIVE_SCATTER +#if (defined(HWY_NATIVE_LEADING_ZERO_COUNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT #else -#define HWY_NATIVE_SCATTER +#define HWY_NATIVE_LEADING_ZERO_COUNT #endif -template > -HWY_API void ScatterOffset(VFromD v, D d, T* HWY_RESTRICT base, - VFromD> offset) { +namespace detail { + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + const Repartition di16; - HWY_ALIGN T lanes[MaxLanes(d)]; - Store(v, d, lanes); + // On SSE2/SSSE3/SSE4/AVX2, do an int32_t to float conversion, followed + // by a unsigned right shift of the uint32_t bit representation of the + // floating point values by 23, followed by an int16_t Min + // operation as we are only interested in the biased exponent that would + // result from a uint32_t to float conversion. - HWY_ALIGN TI offset_lanes[MaxLanes(d)]; - Store(offset, di, offset_lanes); + // An int32_t to float vector conversion is also much more efficient on + // SSE2/SSSE3/SSE4/AVX2 than an uint32_t vector to float vector conversion + // as an uint32_t vector to float vector conversion on SSE2/SSSE3/SSE4/AVX2 + // requires multiple instructions whereas an int32_t to float vector + // conversion can be carried out using a single instruction on + // SSE2/SSSE3/SSE4/AVX2. - uint8_t* base_bytes = reinterpret_cast(base); - for (size_t i = 0; i < MaxLanes(d); ++i) { - CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); - } + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(di, v))); + return BitCast(d, Min(BitCast(di16, ShiftRight<23>(f32_bits)), + BitCast(di16, Set(d, 158)))); +#else + const auto f32_bits = BitCast(d, ConvertTo(df, v)); + return BitCast(d, ShiftRight<23>(f32_bits)); +#endif } -template > -HWY_API void ScatterIndex(VFromD v, D d, T* HWY_RESTRICT base, - VFromD> index) { +template )> +HWY_INLINE V I32RangeU32ToF32BiasedExp(V v) { + // I32RangeU32ToF32BiasedExp is similar to UIntToF32BiasedExp, but + // I32RangeU32ToF32BiasedExp assumes that v[i] is between 0 and 2147483647. 
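The comment block above (int-to-float conversion, unsigned right shift by 23, clamp of the biased exponent) is the core of the LeadingZeroCount / HighestSetBitIndex / TrailingZeroCount fallbacks added later in this hunk. A scalar sketch of the same trick, assuming uint32_t lanes; `F32BiasedExp` and the other names are illustrative, not part of the library.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Biased exponent of float(x): for x > 0 this is floor(log2(x)) + 127.
static uint32_t F32BiasedExp(uint32_t x) {
  // Mirrors NormalizeForUIntTruncConvToF32: zero the bit 24 places below each
  // set bit so that an inexact conversion of x >= 2^24 rounds down, keeping
  // the exponent at floor(log2(x)) + 127.
  x &= ~(x >> 24);
  const float f = static_cast<float>(x);
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits >> 23;  // sign bit is zero because x >= 0
}

static int HighestSetBitIndex(uint32_t x) {  // meaningless for x == 0
  return static_cast<int>(F32BiasedExp(x)) - 127;
}

static int LeadingZeroCount(uint32_t x) {
  const int lz = 32 + 126 - static_cast<int>(F32BiasedExp(x));
  return lz > 32 ? 32 : lz;  // x == 0 maps to 32, like the clamped vector code
}

static int TrailingZeroCount(uint32_t x) {
  if (x == 0) return 32;  // the vector code handles this via an unsigned Min
  const uint32_t lowest = x & (0u - x);  // isolate the lowest set bit
  return static_cast<int>(F32BiasedExp(lowest)) - 127;
}

int main() {
  std::printf("%d %d %d\n", HighestSetBitIndex(0x00F00000u),  // 23
              LeadingZeroCount(1u),                           // 31
              TrailingZeroCount(40u));                        // 3
}
```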
+ const DFromV d; + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 + const RebindToSigned d_src; +#else + const RebindToUnsigned d_src; +#endif + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(d_src, v))); + return ShiftRight<23>(f32_bits); +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return TruncateTo(d, f32_biased_exp_as_u32); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + return BitCast(d, + OrderedDemote2To(di, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, hi_f32_biased_exp_as_u32))); +#else + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u32, + hi_f32_biased_exp_as_u32); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR - HWY_ALIGN T lanes[MaxLanes(d)]; - Store(v, d, lanes); +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return U8FromU32(f32_biased_exp_as_u32); +} - HWY_ALIGN TI index_lanes[MaxLanes(d)]; - Store(index, di, index_lanes); +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + const Repartition du16; - for (size_t i = 0; i < MaxLanes(d); ++i) { - base[index_lanes[i]] = lanes[i]; - } + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); + +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + const auto f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, hi_f32_biased_exp_as_u32)); + return DemoteTo(d, f32_biased_exp_as_i16); +#else + const auto f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, lo_f32_biased_exp_as_u32, hi_f32_biased_exp_as_u32); + return TruncateTo(d, f32_biased_exp_as_u16); +#endif } -template > -HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, - T* HWY_RESTRICT base, - VFromD> index) { - const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Half dq; + const Rebind du32; + const Repartition du16; - HWY_ALIGN T lanes[MaxLanes(d)]; - Store(v, d, lanes); + const auto lo_half = LowerHalf(dh, v); + const auto hi_half = UpperHalf(dh, v); - HWY_ALIGN TI index_lanes[MaxLanes(d)]; - Store(index, di, index_lanes); + const auto u32_q0 = PromoteTo(du32, LowerHalf(dq, lo_half)); + const auto u32_q1 = PromoteTo(du32, UpperHalf(dq, lo_half)); + const auto u32_q2 = PromoteTo(du32, LowerHalf(dq, hi_half)); + const auto u32_q3 = PromoteTo(du32, UpperHalf(dq, hi_half)); - HWY_ALIGN TI 
mask_lanes[MaxLanes(di)]; - Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + const auto f32_biased_exp_as_u32_q0 = I32RangeU32ToF32BiasedExp(u32_q0); + const auto f32_biased_exp_as_u32_q1 = I32RangeU32ToF32BiasedExp(u32_q1); + const auto f32_biased_exp_as_u32_q2 = I32RangeU32ToF32BiasedExp(u32_q2); + const auto f32_biased_exp_as_u32_q3 = I32RangeU32ToF32BiasedExp(u32_q3); - for (size_t i = 0; i < MaxLanes(d); ++i) { - if (mask_lanes[i]) base[index_lanes[i]] = lanes[i]; - } +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + + const auto lo_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q0), + BitCast(di32, f32_biased_exp_as_u32_q1)); + const auto hi_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q2), + BitCast(di32, f32_biased_exp_as_u32_q3)); + return OrderedDemote2To(d, lo_f32_biased_exp_as_i16, + hi_f32_biased_exp_as_i16); +#else + const auto lo_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q0, f32_biased_exp_as_u32_q1); + const auto hi_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q2, f32_biased_exp_as_u32_q3); + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u16, + hi_f32_biased_exp_as_u16); +#endif } +#endif // HWY_TARGET != HWY_SCALAR -#endif // (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) - -// ------------------------------ Gather - -#if (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_GATHER -#undef HWY_NATIVE_GATHER +#if HWY_TARGET == HWY_SCALAR +template +using F32ExpLzcntMinMaxRepartition = RebindToUnsigned; +#elif HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2 +template +using F32ExpLzcntMinMaxRepartition = Repartition; #else -#define HWY_NATIVE_GATHER +template +using F32ExpLzcntMinMaxRepartition = + Repartition), 4)>, D>; #endif -template > -HWY_API VFromD GatherOffset(D d, const T* HWY_RESTRICT base, - VFromD> offset) { - const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); - - HWY_ALIGN TI offset_lanes[MaxLanes(d)]; - Store(offset, di, offset_lanes); +template +using F32ExpLzcntMinMaxCmpV = VFromD>>; - HWY_ALIGN T lanes[MaxLanes(d)]; - const uint8_t* base_bytes = reinterpret_cast(base); - for (size_t i = 0; i < MaxLanes(d); ++i) { - CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); - } - return Load(d, lanes); +template +HWY_INLINE F32ExpLzcntMinMaxCmpV F32ExpLzcntMinMaxBitCast(V v) { + const DFromV d; + const F32ExpLzcntMinMaxRepartition d2; + return BitCast(d2, v); } -template > -HWY_API VFromD GatherIndex(D d, const T* HWY_RESTRICT base, - VFromD> index) { - const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); - - HWY_ALIGN TI index_lanes[MaxLanes(d)]; - Store(index, di, index_lanes); +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { +#if HWY_TARGET == HWY_SCALAR + const uint64_t u64_val = GetLane(v); + const float f32_val = static_cast(u64_val); + const uint32_t f32_bits = BitCastScalar(f32_val); + return Set(d, static_cast(f32_bits >> 23)); +#else + const Repartition du32; + const auto f32_biased_exp = UIntToF32BiasedExp(du32, BitCast(du32, v)); + const auto f32_biased_exp_adj = + IfThenZeroElse(Eq(f32_biased_exp, Zero(du32)), + BitCast(du32, Set(d, 0x0000002000000000u))); + const auto adj_f32_biased_exp = Add(f32_biased_exp, f32_biased_exp_adj); - HWY_ALIGN T lanes[MaxLanes(d)]; - for (size_t i = 0; 
i < MaxLanes(d); ++i) { - lanes[i] = base[index_lanes[i]]; - } - return Load(d, lanes); + return ShiftRight<32>(BitCast( + d, Max(F32ExpLzcntMinMaxBitCast(adj_f32_biased_exp), + F32ExpLzcntMinMaxBitCast(Reverse2(du32, adj_f32_biased_exp))))); +#endif } -template > -HWY_API VFromD MaskedGatherIndex(MFromD m, D d, - const T* HWY_RESTRICT base, - VFromD> index) { - const RebindToSigned di; - using TI = TFromD; - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); +template +HWY_INLINE V UIntToF32BiasedExp(V v) { + const DFromV d; + return UIntToF32BiasedExp(d, v); +} - HWY_ALIGN TI index_lanes[MaxLanes(di)]; - Store(index, di, index_lanes); +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + return v; +} - HWY_ALIGN TI mask_lanes[MaxLanes(di)]; - Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + // If v[i] >= 16777216 is true, make sure that the bit at + // HighestSetBitIndex(v[i]) - 24 is zeroed out to ensure that any inexact + // conversion to single-precision floating point is rounded down. - HWY_ALIGN T lanes[MaxLanes(d)]; - for (size_t i = 0; i < MaxLanes(d); ++i) { - lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : T{0}; - } - return Load(d, lanes); + // This zeroing-out can be accomplished through the AndNot operation below. + return AndNot(ShiftRight<24>(v), v); } -#endif // (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) +} // namespace detail -// ------------------------------ ScatterN/GatherN +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; -template > -HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, - VFromD> index, - const size_t max_lanes_to_store) { - MaskedScatterIndex(v, FirstN(d, max_lanes_to_store), d, base, index); + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + return BitCast(d, Sub(f32_biased_exp, Set(du, TU{127}))); } -template > -HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, - VFromD> index, - const size_t max_lanes_to_load) { - return MaskedGatherIndex(FirstN(d, max_lanes_to_load), d, base, index); -} +template +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; -// ------------------------------ Integer AbsDiff and SumsOf8AbsDiff + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + const auto lz_count = Sub(Set(du, TU{kNumOfBitsInT + 126}), f32_biased_exp); -#if (defined(HWY_NATIVE_INTEGER_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_INTEGER_ABS_DIFF -#undef HWY_NATIVE_INTEGER_ABS_DIFF -#else -#define HWY_NATIVE_INTEGER_ABS_DIFF -#endif + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(lz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +} template -HWY_API V AbsDiff(V a, V b) { - return Sub(Max(a, b), Min(a, b)); -} +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; -#endif // HWY_NATIVE_INTEGER_ABS_DIFF + const auto vi = BitCast(di, v); + const auto lowest_bit = BitCast(du, And(vi, Neg(vi))); -#if (defined(HWY_NATIVE_SUMS_OF_8_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF -#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF -#else -#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF -#endif + constexpr TU 
kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = detail::UIntToF32BiasedExp(lowest_bit); + const auto tz_count = Sub(f32_biased_exp, Set(du, TU{127})); -template ), - HWY_IF_V_SIZE_GT_D(DFromV, (HWY_TARGET == HWY_SCALAR ? 0 : 4))> -HWY_API Vec>> SumsOf8AbsDiff(V a, V b) { - return SumsOf8(AbsDiff(a, b)); + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(tz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); } +#endif // HWY_NATIVE_LEADING_ZERO_COUNT -#endif // HWY_NATIVE_SUMS_OF_8_ABS_DIFF - -// ------------------------------ SaturatedAdd/SaturatedSub for UI32/UI64 +// ------------------------------ AESRound -#if (defined(HWY_NATIVE_I32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB -#undef HWY_NATIVE_I32_SATURATED_ADDSUB -#else -#define HWY_NATIVE_I32_SATURATED_ADDSUB -#endif +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE -template )> -HWY_API V SaturatedAdd(V a, V b) { - const DFromV d; - const auto sum = Add(a, b); - const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); - const auto overflow_result = - Xor(BroadcastSignBit(a), Set(d, LimitsMax())); - return IfNegativeThenElse(overflow_mask, overflow_result, sum); -} +// Define for white-box testing, even if native instructions are available. +namespace detail { -template )> -HWY_API V SaturatedSub(V a, V b) { - const DFromV d; - const auto diff = Sub(a, b); - const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); - const auto overflow_result = - Xor(BroadcastSignBit(a), Set(d, LimitsMax())); - return IfNegativeThenElse(overflow_mask, overflow_result, diff); -} +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. -#endif // HWY_NATIVE_I32_SATURATED_ADDSUB +template // u8 +HWY_INLINE V SubBytesMulInverseAndAffineLookup(V state, V affine_tblL, + V affine_tblU) { + const DFromV du; + const auto mask = Set(du, uint8_t{0xF}); -#if (defined(HWY_NATIVE_I64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB -#undef HWY_NATIVE_I64_SATURATED_ADDSUB -#else -#define HWY_NATIVE_I64_SATURATED_ADDSUB -#endif + // Change polynomial basis to GF(2^4) + { + const VFromD basisL = + Dup128VecFromValues(du, 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA); + const VFromD basisU = + Dup128VecFromValues(du, 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD); + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(basisL, sL); + const auto gf4U = TableLookupBytes(basisU, sU); + state = Xor(gf4L, gf4U); + } -template )> -HWY_API V SaturatedAdd(V a, V b) { - const DFromV d; - const auto sum = Add(a, b); - const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); - const auto overflow_result = - Xor(BroadcastSignBit(a), Set(d, LimitsMax())); - return IfNegativeThenElse(overflow_mask, overflow_result, sum); -} + // Inversion in GF(2^4). 
Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. + const VFromD zetaInv = Dup128VecFromValues( + du, 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3); + const VFromD tbl = Dup128VecFromValues( + du, 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(zetaInv, sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); -template )> -HWY_API V SaturatedSub(V a, V b) { - const DFromV d; - const auto diff = Sub(a, b); - const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); - const auto overflow_result = - Xor(BroadcastSignBit(a), Set(d, LimitsMax())); - return IfNegativeThenElse(overflow_mask, overflow_result, diff); + const auto affL = TableLookupBytesOr0(affine_tblL, outL); + const auto affU = TableLookupBytesOr0(affine_tblU, outU); + return Xor(affL, affU); } -#endif // HWY_NATIVE_I64_SATURATED_ADDSUB +template // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV du; + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). + const VFromD affineL = + Dup128VecFromValues(du, 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15); + const VFromD affineU = + Dup128VecFromValues(du, 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E); + return Xor(SubBytesMulInverseAndAffineLookup(state, affineL, affineU), + Set(du, uint8_t{0x63})); +} -#if (defined(HWY_NATIVE_U32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB -#undef HWY_NATIVE_U32_SATURATED_ADDSUB -#else -#define HWY_NATIVE_U32_SATURATED_ADDSUB -#endif +template // u8 +HWY_INLINE V InvSubBytes(V state) { + const DFromV du; + const VFromD gF2P4InvToGF2P8InvL = + Dup128VecFromValues(du, 0x00, 0x40, 0xF9, 0x7E, 0x53, 0xEA, 0x87, 0x13, + 0x2D, 0x3E, 0x94, 0xD4, 0xB9, 0x6D, 0xAA, 0xC7); + const VFromD gF2P4InvToGF2P8InvU = + Dup128VecFromValues(du, 0x00, 0x1D, 0x44, 0x93, 0x0F, 0x56, 0xD7, 0x12, + 0x9C, 0x8E, 0xC5, 0xD8, 0x59, 0x81, 0x4B, 0xCA); -template )> -HWY_API V SaturatedAdd(V a, V b) { - return Add(a, Min(b, Not(a))); -} + // Apply the inverse affine transformation + const auto b = Xor(Xor3(Or(ShiftLeft<1>(state), ShiftRight<7>(state)), + Or(ShiftLeft<3>(state), ShiftRight<5>(state)), + Or(ShiftLeft<6>(state), ShiftRight<2>(state))), + Set(du, uint8_t{0x05})); -template )> -HWY_API V SaturatedSub(V a, V b) { - return Sub(a, Min(a, b)); + // The GF(2^8) multiplicative inverse is computed as follows: + // - Changing the polynomial basis to GF(2^4) + // - Computing the GF(2^4) multiplicative inverse + // - Converting the GF(2^4) multiplicative inverse to the GF(2^8) + // multiplicative inverse through table lookups using the + // kGF2P4InvToGF2P8InvL and kGF2P4InvToGF2P8InvU tables + return SubBytesMulInverseAndAffineLookup(b, gF2P4InvToGF2P8InvL, + gF2P4InvToGF2P8InvU); } -#endif // HWY_NATIVE_U32_SATURATED_ADDSUB +} // namespace detail -#if (defined(HWY_NATIVE_U64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB -#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#endif // 
HWY_TARGET != HWY_SCALAR + +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES #else -#define HWY_NATIVE_U64_SATURATED_ADDSUB +#define HWY_NATIVE_AES #endif -template )> -HWY_API V SaturatedAdd(V a, V b) { - return Add(a, Min(b, Not(a))); -} +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE -template )> -HWY_API V SaturatedSub(V a, V b) { - return Sub(a, Min(a, b)); +namespace detail { + +template // u8 +HWY_INLINE V ShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = Dup128VecFromValues( + du, 0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11); + return TableLookupBytes(state, shift_row); } -#endif // HWY_NATIVE_U64_SATURATED_ADDSUB +template // u8 +HWY_INLINE V InvShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = Dup128VecFromValues( + du, 0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3); + return TableLookupBytes(state, shift_row); +} -// ------------------------------ Unsigned to signed demotions +template // u8 +HWY_INLINE V GF2P8Mod11BMulBy2(V v) { + const DFromV du; + const RebindToSigned di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, v), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); + return Xor(Add(v, v), overflow); // = v*2 in GF(2^8). +} -template , DN>>, - hwy::EnableIf<(sizeof(TFromD) < sizeof(TFromV))>* = nullptr, - HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> -HWY_API VFromD DemoteTo(DN dn, V v) { - const DFromV d; - const RebindToSigned di; - const RebindToUnsigned dn_u; +template // u8 +HWY_INLINE V MixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); + const auto d = GF2P8Mod11BMulBy2(state); // = state*2 in GF(2^8). + const auto s2301 = TableLookupBytes(state, v2301); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, v1230); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} - // First, do a signed to signed demotion. This will convert any values - // that are greater than hwy::HighestValue>>() to a - // negative value. - const auto i2i_demote_result = DemoteTo(dn, BitCast(di, v)); +template // u8 +HWY_INLINE V InvMixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 14 11 13 9 + // 9 14 11 13 + // 13 9 14 11 + // 11 13 9 14 + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); - // Second, convert any negative values to hwy::HighestValue>() - // using an unsigned Min operation. 
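The `2 3 1 1` matrix comment in MixColumns above maps onto a very small scalar reference; here is one, assuming nothing beyond standard C++ (not the Highway API), checked against the well-known test column db 13 53 45 -> 8e 4d a1 bc.

```cpp
#include <cstdint>
#include <cstdio>

// Multiply by 2 in GF(2^8) modulo 0x11B, i.e. what GF2P8Mod11BMulBy2 computes
// per lane.
static uint8_t XTime(uint8_t v) {
  return static_cast<uint8_t>((v << 1) ^ ((v & 0x80) ? 0x1B : 0x00));
}

// One MixColumns column: row i is 2*s[i] ^ 3*s[i+1] ^ s[i+2] ^ s[i+3]
// (indices mod 4), with 3*x computed as 2*x ^ x, matching the comment above.
static void MixColumn(const uint8_t s[4], uint8_t out[4]) {
  for (int i = 0; i < 4; ++i) {
    const uint8_t two = XTime(s[i]);
    const uint8_t three =
        static_cast<uint8_t>(XTime(s[(i + 1) % 4]) ^ s[(i + 1) % 4]);
    out[i] =
        static_cast<uint8_t>(two ^ three ^ s[(i + 2) % 4] ^ s[(i + 3) % 4]);
  }
}

int main() {
  const uint8_t col[4] = {0xDB, 0x13, 0x53, 0x45};
  uint8_t out[4];
  MixColumn(col, out);
  std::printf("%02X %02X %02X %02X\n", out[0], out[1], out[2], out[3]);
  // 8E 4D A1 BC
}
```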
- const auto max_signed_val = Set(dn, hwy::HighestValue>()); + const auto sx2 = GF2P8Mod11BMulBy2(state); /* = state*2 in GF(2^8) */ + const auto sx4 = GF2P8Mod11BMulBy2(sx2); /* = state*4 in GF(2^8) */ + const auto sx8 = GF2P8Mod11BMulBy2(sx4); /* = state*8 in GF(2^8) */ + const auto sx9 = Xor(sx8, state); /* = state*9 in GF(2^8) */ + const auto sx11 = Xor(sx9, sx2); /* = state*11 in GF(2^8) */ + const auto sx13 = Xor(sx9, sx4); /* = state*13 in GF(2^8) */ + const auto sx14 = Xor3(sx8, sx4, sx2); /* = state*14 in GF(2^8) */ - return BitCast( - dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); + const auto sx13_0123_sx9_1230 = Xor(sx13, TableLookupBytes(sx9, v1230)); + const auto sx14_0123_sx11_1230 = Xor(sx14, TableLookupBytes(sx11, v1230)); + const auto sx13_2301_sx9_3012 = TableLookupBytes(sx13_0123_sx9_1230, v2301); + return Xor(sx14_0123_sx11_1230, sx13_2301_sx9_3012); } -#if HWY_TARGET != HWY_SCALAR || HWY_IDE -template , DN>>, - HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), - HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> -HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { - const DFromV d; - const RebindToSigned di; - const RebindToUnsigned dn_u; +} // namespace detail - // First, do a signed to signed demotion. This will convert any values - // that are greater than hwy::HighestValue>>() to a - // negative value. - const auto i2i_demote_result = - ReorderDemote2To(dn, BitCast(di, a), BitCast(di, b)); +template // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} - // Second, convert any negative values to hwy::HighestValue>() - // using an unsigned Min operation. - const auto max_signed_val = Set(dn, hwy::HighestValue>()); +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} - return BitCast( - dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); +template +HWY_API V AESInvMixColumns(V state) { + return detail::InvMixColumns(state); } -#endif -// ------------------------------ PromoteLowerTo +template // u8 +HWY_API V AESRoundInv(V state, const V round_key) { + state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = detail::InvMixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} -// There is no codegen advantage for a native version of this. It is provided -// only for convenience. -template -HWY_API VFromD PromoteLowerTo(D d, V v) { - // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V - // because it cannot be deduced from D (could be either bf16 or f16). - const Rebind, decltype(d)> dh; - return PromoteTo(d, LowerHalf(dh, v)); +template // u8 +HWY_API V AESLastRoundInv(V state, const V round_key) { + // Like AESRoundInv, but without InvMixColumns. 
+ state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; } -// ------------------------------ PromoteUpperTo +template )> +HWY_API V AESKeyGenAssist(V v) { + const DFromV d; + const V rconXorMask = Dup128VecFromValues(d, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, + 0, 0, kRcon, 0, 0, 0); + const V rotWordShuffle = Dup128VecFromValues(d, 4, 5, 6, 7, 5, 6, 7, 4, 12, + 13, 14, 15, 13, 14, 15, 12); + const auto sub_word_result = detail::SubBytes(v); + const auto rot_word_result = + TableLookupBytes(sub_word_result, rotWordShuffle); + return Xor(rot_word_result, rconXorMask); +} -#if (defined(HWY_NATIVE_PROMOTE_UPPER_TO) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_PROMOTE_UPPER_TO -#undef HWY_NATIVE_PROMOTE_UPPER_TO -#else -#define HWY_NATIVE_PROMOTE_UPPER_TO -#endif +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. +template +HWY_API V CLMulLower(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); -// This requires UpperHalf. -#if HWY_TARGET != HWY_SCALAR || HWY_IDE + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} -template -HWY_API VFromD PromoteUpperTo(D d, V v) { - // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V - // because it cannot be deduced from D (could be either bf16 or f16). 
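CLMulLower above is the vectorized form of the masked-multiply carry-less product described on the bearssl.org page cited in its comment. A scalar model of the same idea for the low 64 bits, using plain 64-bit multiplies; the helper name is illustrative only.

```cpp
#include <cstdint>
#include <cstdio>

// Low 64 bits of the carry-less (GF(2)[x]) product of a and b. Operand bits
// are split into four interleaved classes (masks k1/k2/k4/k8) so that the
// spacing leaves enough zero bits for the carries of the ordinary multiplies
// not to reach the bit positions that are kept.
static uint64_t CLMulLower64(uint64_t a, uint64_t b) {
  const uint64_t k1 = 0x1111111111111111ULL;
  const uint64_t k2 = k1 << 1, k4 = k1 << 2, k8 = k1 << 3;
  const uint64_t a0 = a & k1, a1 = a & k2, a2 = a & k4, a3 = a & k8;
  const uint64_t b0 = b & k1, b1 = b & k2, b2 = b & k4, b3 = b & k8;
  const uint64_t m0 = (a0 * b0) ^ (a1 * b3) ^ (a2 * b2) ^ (a3 * b1);
  const uint64_t m1 = (a0 * b1) ^ (a1 * b0) ^ (a2 * b3) ^ (a3 * b2);
  const uint64_t m2 = (a0 * b2) ^ (a1 * b1) ^ (a2 * b0) ^ (a3 * b3);
  const uint64_t m3 = (a0 * b3) ^ (a1 * b2) ^ (a2 * b1) ^ (a3 * b0);
  return (m0 & k1) | (m1 & k2) | (m2 & k4) | (m3 & k8);
}

int main() {
  // (x + 1)^2 = x^2 + 1 and (x^2 + 1)(x^2 + x) = x^4 + x^3 + x^2 + x.
  std::printf("%llu %llu\n",
              static_cast<unsigned long long>(CLMulLower64(3, 3)),   // 5
              static_cast<unsigned long long>(CLMulLower64(5, 6)));  // 30
}
```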
- const Rebind, decltype(d)> dh; - return PromoteTo(d, UpperHalf(dh, v)); +template +HWY_API V CLMulUpper(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); } +#endif // HWY_NATIVE_AES #endif // HWY_TARGET != HWY_SCALAR -#endif // HWY_NATIVE_PROMOTE_UPPER_TO -// ------------------------------ float16_t <-> float +// ------------------------------ PopulationCount -#if (defined(HWY_NATIVE_F16C) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_F16C -#undef HWY_NATIVE_F16C +#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT #else -#define HWY_NATIVE_F16C +#define HWY_NATIVE_POPCNT #endif -template -HWY_API VFromD PromoteTo(D df32, VFromD> v) { - const RebindToSigned di32; - const RebindToUnsigned du32; - const Rebind du16; - using VU32 = VFromD; +// This overload requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. +#undef HWY_IF_POPCNT +#if HWY_TARGET == HWY_RVV +#define HWY_IF_POPCNT(D) \ + hwy::EnableIf= 1 && D().MaxLanes() >= 16>* = nullptr +#else +// Other targets only have these two overloads which are mutually exclusive, so +// no further conditions are required. +#define HWY_IF_POPCNT(D) void* = nullptr +#endif // HWY_TARGET == HWY_RVV - const VU32 bits16 = PromoteTo(du32, BitCast(du16, v)); - const VU32 sign = ShiftRight<15>(bits16); - const VU32 biased_exp = And(ShiftRight<10>(bits16), Set(du32, 0x1F)); - const VU32 mantissa = And(bits16, Set(du32, 0x3FF)); - const VU32 subnormal = - BitCast(du32, Mul(ConvertTo(df32, BitCast(di32, mantissa)), - Set(df32, 1.0f / 16384 / 1024))); +template , HWY_IF_U8_D(D), + HWY_IF_V_SIZE_GT_D(D, 8), HWY_IF_POPCNT(D)> +HWY_API V PopulationCount(V v) { + const D d; + const V lookup = + Dup128VecFromValues(d, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4); + const auto lo = And(v, Set(d, uint8_t{0xF})); + const auto hi = ShiftRight<4>(v); + return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +} - const VU32 biased_exp32 = Add(biased_exp, Set(du32, 127 - 15)); - const VU32 mantissa32 = ShiftLeft<23 - 10>(mantissa); - const VU32 normal = Or(ShiftLeft<23>(biased_exp32), mantissa32); - const VU32 bits32 = IfThenElse(Eq(biased_exp, Zero(du32)), subnormal, normal); - return BitCast(df32, Or(ShiftLeft<31>(sign), bits32)); +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. 
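The 16-entry vector loaded by Dup128VecFromValues in the PopulationCount overload above is simply a per-nibble bit-count table; a scalar sketch of the same lookup, with illustrative names, before the capped-vector fallback that follows.

```cpp
#include <cstdint>
#include <cstdio>

// Bit counts of the 16 possible nibble values, the same constants as the
// Dup128VecFromValues lookup above.
static const uint8_t kNibblePop[16] = {0, 1, 1, 2, 1, 2, 2, 3,
                                       1, 2, 2, 3, 2, 3, 3, 4};

static uint8_t PopCount8(uint8_t v) {
  return static_cast<uint8_t>(kNibblePop[v & 0xF] + kNibblePop[v >> 4]);
}

int main() { std::printf("%d\n", PopCount8(0xB7)); }  // 6
```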
+template , HWY_IF_U8_D(D), + HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API V PopulationCount(V v) { + const D d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + const V k33 = Set(d, uint8_t{0x33}); + v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); + v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); + return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); } +#endif // HWY_TARGET != HWY_RVV -template -HWY_API VFromD DemoteTo(D df16, VFromD> v) { - const RebindToUnsigned du16; - const Rebind du32; - const RebindToSigned di32; - using VU32 = VFromD; - using VI32 = VFromD; - - const VU32 bits32 = BitCast(du32, v); - const VU32 sign = ShiftRight<31>(bits32); - const VU32 biased_exp32 = And(ShiftRight<23>(bits32), Set(du32, 0xFF)); - const VU32 mantissa32 = And(bits32, Set(du32, 0x7FFFFF)); - - const VI32 k15 = Set(di32, 15); - const VI32 exp = Min(Sub(BitCast(di32, biased_exp32), Set(di32, 127)), k15); - const MFromD is_tiny = Lt(exp, Set(di32, -24)); - - const MFromD is_subnormal = Lt(exp, Set(di32, -14)); - const VU32 biased_exp16 = - BitCast(du32, IfThenZeroElse(is_subnormal, Add(exp, k15))); - const VU32 sub_exp = BitCast(du32, Sub(Set(di32, -14), exp)); // [1, 11) - // Clamp shift counts to prevent warnings in emu_128 Shr. - const VU32 k31 = Set(du32, 31); - const VU32 shift_m = Min(Add(Set(du32, 13), sub_exp), k31); - const VU32 shift_1 = Min(Sub(Set(du32, 10), sub_exp), k31); - const VU32 sub_m = Add(Shl(Set(du32, 1), shift_1), Shr(mantissa32, shift_m)); - const VU32 mantissa16 = IfThenElse(RebindMask(du32, is_subnormal), sub_m, - ShiftRight<13>(mantissa32)); // <1024 - - const VU32 sign16 = ShiftLeft<15>(sign); - const VU32 normal16 = Or3(sign16, ShiftLeft<10>(biased_exp16), mantissa16); - const VI32 bits16 = IfThenZeroElse(is_tiny, BitCast(di32, normal16)); - return BitCast(df16, DemoteTo(du16, bits16)); +template , HWY_IF_U16_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); } -#endif // HWY_NATIVE_F16C +template , HWY_IF_U32_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + Repartition d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +} -// ------------------------------ OrderedTruncate2To +#if HWY_HAVE_INTEGER64 +template , HWY_IF_U64_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + Repartition d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +} +#endif -#if HWY_IDE || \ - (defined(HWY_NATIVE_ORDERED_TRUNCATE_2_TO) == defined(HWY_TARGET_TOGGLE)) +#endif // HWY_NATIVE_POPCNT -#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO -#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +// ------------------------------ 8-bit multiplication + +#if (defined(HWY_NATIVE_MUL_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 #else -#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#define HWY_NATIVE_MUL_8 #endif -// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) -#if HWY_TARGET != HWY_SCALAR || HWY_IDE -template ) * 2), - HWY_IF_LANES_D(DFromV>, HWY_MAX_LANES_D(DFromV) * 2)> -HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { - return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +// 8 bit and fits in wider reg: promote +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const 
Rebind>, decltype(d)> dw; + const RebindToUnsigned du; // TruncateTo result + const RebindToUnsigned dwu; // TruncateTo input + const VFromD mul = PromoteTo(dw, a) * PromoteTo(dw, b); + // TruncateTo is cheaper than ConcatEven. + return BitCast(d, TruncateTo(du, BitCast(dwu, mul))); } -#endif // HWY_TARGET != HWY_SCALAR -#endif // HWY_NATIVE_ORDERED_TRUNCATE_2_TO -// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex +// 8 bit full reg: promote halves +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const Half dh; + const Twice> dw; + const VFromD a0 = PromoteTo(dw, LowerHalf(dh, a)); + const VFromD a1 = PromoteTo(dw, UpperHalf(dh, a)); + const VFromD b0 = PromoteTo(dw, LowerHalf(dh, b)); + const VFromD b1 = PromoteTo(dw, UpperHalf(dh, b)); + const VFromD m0 = a0 * b0; + const VFromD m1 = a1 * b1; + return ConcatEven(d, BitCast(d, m1), BitCast(d, m0)); +} -#if (defined(HWY_NATIVE_LEADING_ZERO_COUNT) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_LEADING_ZERO_COUNT -#undef HWY_NATIVE_LEADING_ZERO_COUNT +#endif // HWY_NATIVE_MUL_8 + +// ------------------------------ 64-bit multiplication + +#if (defined(HWY_NATIVE_MUL_64) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 #else -#define HWY_NATIVE_LEADING_ZERO_COUNT +#define HWY_NATIVE_MUL_64 #endif -namespace detail { +// Single-lane i64 or u64 +template +HWY_API V operator*(V x, V y) { + const DFromV d; + using T = TFromD; + using TU = MakeUnsigned; + const TU xu = static_cast(GetLane(x)); + const TU yu = static_cast(GetLane(y)); + return Set(d, static_cast(xu * yu)); +} -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const RebindToFloat df; -#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 - const RebindToSigned di; - const Repartition di16; +template , HWY_IF_U64_D(D64), + HWY_IF_V_SIZE_GT_D(D64, 8)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template , HWY_IF_I64_D(DI64), + HWY_IF_V_SIZE_GT_D(DI64, 8)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} - // On SSE2/SSSE3/SSE4/AVX2, do an int32_t to float conversion, followed - // by a unsigned right shift of the uint32_t bit representation of the - // floating point values by 23, followed by an int16_t Min - // operation as we are only interested in the biased exponent that would - // result from a uint32_t to float conversion. +#endif // HWY_NATIVE_MUL_64 - // An int32_t to float vector conversion is also much more efficient on - // SSE2/SSSE3/SSE4/AVX2 than an uint32_t vector to float vector conversion - // as an uint32_t vector to float vector conversion on SSE2/SSSE3/SSE4/AVX2 - // requires multiple instructions whereas an int32_t to float vector - // conversion can be carried out using a single instruction on - // SSE2/SSSE3/SSE4/AVX2. 
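A scalar model of how the MulEven-based u64 operator* earlier in this hunk assembles the low 64 bits of a product from 32-bit partial products (lolo + ((lohi + hilo) << 32)); plain C++ with illustrative names, not the Highway API.

```cpp
#include <cstdint>
#include <cstdio>

// Low 64 bits of x*y from 32-bit halves: the hi*hi term and the upper halves
// of the shifted cross terms fall entirely above bit 63 and can be dropped.
static uint64_t Mul64Via32(uint64_t x, uint64_t y) {
  const uint64_t xlo = x & 0xFFFFFFFFu, xhi = x >> 32;
  const uint64_t ylo = y & 0xFFFFFFFFu, yhi = y >> 32;
  const uint64_t lolo = xlo * ylo;
  const uint64_t lohi = xlo * yhi;
  const uint64_t hilo = xhi * ylo;
  return lolo + ((lohi + hilo) << 32);
}

int main() {
  const uint64_t x = 0x123456789ABCDEF0ULL, y = 0x0FEDCBA987654321ULL;
  std::printf("%d\n", static_cast<int>(Mul64Via32(x, y) == x * y));  // 1
}
```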
+// ------------------------------ MulAdd / NegMulAdd - const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(di, v))); - return BitCast(d, Min(BitCast(di16, ShiftRight<23>(f32_bits)), - BitCast(di16, Set(d, 158)))); +#if (defined(HWY_NATIVE_INT_FMA) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA #else - const auto f32_bits = BitCast(d, ConvertTo(df, v)); - return BitCast(d, ShiftRight<23>(f32_bits)); +#define HWY_NATIVE_INT_FMA #endif -} -template )> -HWY_INLINE V I32RangeU32ToF32BiasedExp(V v) { - // I32RangeU32ToF32BiasedExp is similar to UIntToF32BiasedExp, but - // I32RangeU32ToF32BiasedExp assumes that v[i] is between 0 and 2147483647. - const DFromV d; - const RebindToFloat df; -#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 - const RebindToSigned d_src; +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB #else - const RebindToUnsigned d_src; +#define HWY_NATIVE_INT_FMSUB #endif - const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(d_src, v))); - return ShiftRight<23>(f32_bits); -} -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const Rebind du32; - const auto f32_biased_exp_as_u32 = - I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); - return TruncateTo(d, f32_biased_exp_as_u32); +template +HWY_API V MulAdd(V mul, V x, V add) { + return Add(Mul(mul, x), add); } -#if HWY_TARGET != HWY_SCALAR -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const Half dh; - const Rebind du32; +template +HWY_API V NegMulAdd(V mul, V x, V add) { + return Sub(add, Mul(mul, x)); +} - const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); - const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); +template +HWY_API V MulSub(V mul, V x, V sub) { + return Sub(Mul(mul, x), sub); +} +#endif // HWY_NATIVE_INT_FMA - const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); - const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); -#if HWY_TARGET <= HWY_SSE2 - const RebindToSigned di32; - const RebindToSigned di; - return BitCast(d, - OrderedDemote2To(di, BitCast(di32, lo_f32_biased_exp_as_u32), - BitCast(di32, hi_f32_biased_exp_as_u32))); +// ------------------------------ Integer MulSub / NegMulSub +#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB #else - return OrderedTruncate2To(d, lo_f32_biased_exp_as_u32, - hi_f32_biased_exp_as_u32); +#define HWY_NATIVE_INT_FMSUB #endif -} -#endif // HWY_TARGET != HWY_SCALAR -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const Rebind du32; - const auto f32_biased_exp_as_u32 = - I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); - return U8FromU32(f32_biased_exp_as_u32); +template +HWY_API V MulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; + return MulAdd(mul, x, BitCast(d, Neg(BitCast(di, sub)))); } -#if HWY_TARGET != HWY_SCALAR -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const Half dh; - const Rebind du32; - const Repartition du16; - - const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); - const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); +#endif // HWY_NATIVE_INT_FMSUB - const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); - const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); +template +HWY_API V NegMulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; -#if HWY_TARGET <= HWY_SSE2 - const RebindToSigned di32; - const RebindToSigned di16; - const auto 
f32_biased_exp_as_i16 = - OrderedDemote2To(di16, BitCast(di32, lo_f32_biased_exp_as_u32), - BitCast(di32, hi_f32_biased_exp_as_u32)); - return DemoteTo(d, f32_biased_exp_as_i16); -#else - const auto f32_biased_exp_as_u16 = OrderedTruncate2To( - du16, lo_f32_biased_exp_as_u32, hi_f32_biased_exp_as_u32); - return TruncateTo(d, f32_biased_exp_as_u16); -#endif + return BitCast(d, Neg(BitCast(di, MulAdd(mul, x, sub)))); } -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { - const Half dh; - const Half dq; - const Rebind du32; - const Repartition du16; +// ------------------------------ MulAddSub - const auto lo_half = LowerHalf(dh, v); - const auto hi_half = UpperHalf(dh, v); +// MulAddSub(mul, x, sub_or_add) for a 1-lane vector is equivalent to +// MulSub(mul, x, sub_or_add) +template , 1)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + return MulSub(mul, x, sub_or_add); +} - const auto u32_q0 = PromoteTo(du32, LowerHalf(dq, lo_half)); - const auto u32_q1 = PromoteTo(du32, UpperHalf(dq, lo_half)); - const auto u32_q2 = PromoteTo(du32, LowerHalf(dq, hi_half)); - const auto u32_q3 = PromoteTo(du32, UpperHalf(dq, hi_half)); +// MulAddSub for F16/F32/F64 vectors with 2 or more lanes on +// SSSE3/SSE4/AVX2/AVX3 is implemented in x86_128-inl.h, x86_256-inl.h, and +// x86_512-inl.h - const auto f32_biased_exp_as_u32_q0 = I32RangeU32ToF32BiasedExp(u32_q0); - const auto f32_biased_exp_as_u32_q1 = I32RangeU32ToF32BiasedExp(u32_q1); - const auto f32_biased_exp_as_u32_q2 = I32RangeU32ToF32BiasedExp(u32_q2); - const auto f32_biased_exp_as_u32_q3 = I32RangeU32ToF32BiasedExp(u32_q3); +// MulAddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h -#if HWY_TARGET <= HWY_SSE2 - const RebindToSigned di32; - const RebindToSigned di16; +// MulAddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; - const auto lo_f32_biased_exp_as_i16 = - OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q0), - BitCast(di32, f32_biased_exp_as_u32_q1)); - const auto hi_f32_biased_exp_as_i16 = - OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q2), - BitCast(di32, f32_biased_exp_as_u32_q3)); - return OrderedDemote2To(d, lo_f32_biased_exp_as_i16, - hi_f32_biased_exp_as_i16); -#else - const auto lo_f32_biased_exp_as_u16 = OrderedTruncate2To( - du16, f32_biased_exp_as_u32_q0, f32_biased_exp_as_u32_q1); - const auto hi_f32_biased_exp_as_u16 = OrderedTruncate2To( - du16, f32_biased_exp_as_u32_q2, f32_biased_exp_as_u32_q3); - return OrderedTruncate2To(d, lo_f32_biased_exp_as_u16, - hi_f32_biased_exp_as_u16); -#endif + const D d; + const Rebind d_negate; + + const auto add = + OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add)))); + return MulAdd(mul, x, add); } -#endif // HWY_TARGET != HWY_SCALAR -#if HWY_TARGET == HWY_SCALAR -template -using F32ExpLzcntMinMaxRepartition = RebindToUnsigned; -#elif HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2 -template -using F32ExpLzcntMinMaxRepartition = Repartition; +// ------------------------------ Integer division +#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV #else -template -using F32ExpLzcntMinMaxRepartition = - Repartition), 4)>, D>; +#define HWY_NATIVE_INT_DIV #endif -template -using F32ExpLzcntMinMaxCmpV = VFromD>>; - -template -HWY_INLINE F32ExpLzcntMinMaxCmpV F32ExpLzcntMinMaxBitCast(V v) { - const DFromV 
d; - const F32ExpLzcntMinMaxRepartition d2; - return BitCast(d2, v); -} +namespace detail { -template -HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { -#if HWY_TARGET == HWY_SCALAR - const uint64_t u64_val = GetLane(v); - const float f32_val = static_cast(u64_val); - uint32_t f32_bits; - CopySameSize(&f32_val, &f32_bits); - return Set(d, static_cast(f32_bits >> 23)); -#else - const Repartition du32; - const auto f32_biased_exp = UIntToF32BiasedExp(du32, BitCast(du32, v)); - const auto f32_biased_exp_adj = - IfThenZeroElse(Eq(f32_biased_exp, Zero(du32)), - BitCast(du32, Set(d, 0x0000002000000000u))); - const auto adj_f32_biased_exp = Add(f32_biased_exp, f32_biased_exp_adj); +// DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo are okay to use in +// the implementation of detail::IntDiv in generic_ops-inl.h as the current +// implementations of DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo +// will convert values that are outside of the range of TFromD by either +// saturation, truncation, or converting values that are outside of the +// destination range to LimitsMin>() (which is equal to +// static_cast>(LimitsMax>() + 1)) - return ShiftRight<32>(BitCast( - d, Max(F32ExpLzcntMinMaxBitCast(adj_f32_biased_exp), - F32ExpLzcntMinMaxBitCast(Reverse2(du32, adj_f32_biased_exp))))); -#endif +template ))> +HWY_INLINE Vec IntDivConvFloatToInt(D di, V vf) { + return ConvertInRangeTo(di, vf); } -template -HWY_INLINE V UIntToF32BiasedExp(V v) { - const DFromV d; - return UIntToF32BiasedExp(d, v); +template ))> +HWY_INLINE Vec IntDivConvIntToFloat(D df, V vi) { + return ConvertTo(df, vi); } -template -HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { - return v; +#if !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template )> +HWY_INLINE Vec IntDivConvFloatToInt(D df, V vi) { + return PromoteInRangeTo(df, vi); } -template -HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { - // If v[i] >= 16777216 is true, make sure that the bit at - // HighestSetBitIndex(v[i]) - 24 is zeroed out to ensure that any inexact - // conversion to single-precision floating point is rounded down. +// If !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 is true, then UI64->F32 +// IntDivConvIntToFloat(df, vi) returns an approximation of +// static_cast(v[i]) that is within 4 ULP of static_cast(v[i]) +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vi) { + const Twice dt_f32; - // This zeroing-out can be accomplished through the AndNot operation below. 
- return AndNot(ShiftRight<24>(v), v); -} + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToSigned(), vi)); -} // namespace detail +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif -template -HWY_API V HighestSetBitIndex(V v) { - const DFromV d; - const RebindToUnsigned du; - using TU = TFromD; + const RebindToSigned di32; - const auto f32_biased_exp = detail::UIntToF32BiasedExp( - detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); - return BitCast(d, Sub(f32_biased_exp, Set(du, TU{127}))); + hi_f32 = + Add(hi_f32, And(BitCast(df32, BroadcastSignBit(BitCast(di32, lo_f32))), + Set(df32, 1.0f))); + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); } -template -HWY_API V LeadingZeroCount(V v) { - const DFromV d; - const RebindToUnsigned du; - using TU = TFromD; +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vu) { + const Twice dt_f32; - constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; - const auto f32_biased_exp = detail::UIntToF32BiasedExp( - detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); - const auto lz_count = Sub(Set(du, TU{kNumOfBitsInT + 126}), f32_biased_exp); + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToUnsigned(), vu)); - return BitCast(d, - Min(detail::F32ExpLzcntMinMaxBitCast(lz_count), - detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif + + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); } +#endif // !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template , kOrigLaneSize)> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + const DFromV d; + const RebindToFloat df; + + // If kOrigLaneSize < sizeof(T) is true, then a[i] and b[i] are both in the + // [LimitsMin>(), + // LimitsMax>()] range. + + // floor(|a[i] / b[i]|) <= |flt_q| < floor(|a[i] / b[i]|) + 1 is also + // guaranteed to be true if MakeFloat has at least kOrigLaneSize*8 + 1 + // mantissa bits (including the implied one bit), where flt_q is equal to + // static_cast>(a[i]) / static_cast>(b[i]), + // even in the case where the magnitude of an inexact floating point division + // result is rounded up. + + // In other words, floor(flt_q) < flt_q < ceil(flt_q) is guaranteed to be true + // if (a[i] % b[i]) != 0 is true and MakeFloat has at least + // kOrigLaneSize*8 + 1 mantissa bits (including the implied one bit), even in + // the case where the magnitude of an inexact floating point division result + // is rounded up. + + // It is okay to do conversions from MakeFloat> to TFromV using + // ConvertInRangeTo if sizeof(TFromV) > kOrigLaneSize as the result of the + // floating point division is always greater than LimitsMin>() and + // less than LimitsMax>() if sizeof(TFromV) > kOrigLaneSize and + // b[i] != 0. 
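The guarantee spelled out in the comments above (exact truncated quotients whenever the float type has at least kOrigLaneSize*8 + 1 mantissa bits) can be checked with a scalar model for 16-bit lanes, where float's 24-bit mantissa is more than enough. Plain C++, not the Highway API; b == 0 and the INT16_MIN / -1 overflow case are left out, matching the lanes the patch leaves implementation-defined.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// float has 24 mantissa bits (including the implied one), which is at least
// 16 + 1, so truncating the correctly rounded float quotient reproduces the
// exact C-style quotient for 16-bit operands.
static int16_t DivViaFloat(int16_t a, int16_t b) {
  return static_cast<int16_t>(
      std::trunc(static_cast<float>(a) / static_cast<float>(b)));
}

int main() {
  int mismatches = 0;
  const int16_t numerators[] = {-32767, -12345, -1, 0, 1, 19, 1000, 32767};
  for (int16_t a : numerators) {
    for (int b = -3000; b <= 3000; ++b) {
      if (b == 0) continue;
      if (DivViaFloat(a, static_cast<int16_t>(b)) !=
          static_cast<int16_t>(a / b)) {
        ++mismatches;
      }
    }
  }
  std::printf("mismatches: %d\n", mismatches);  // 0
}
```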
+ +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + // On Armv7, do division by multiplying by the ApproximateReciprocal + // to avoid unnecessary overhead as F32 Div refines the approximate + // reciprocal using 4 Newton-Raphson iterations -template -HWY_API V TrailingZeroCount(V v) { - const DFromV d; - const RebindToUnsigned du; const RebindToSigned di; - using TU = TFromD; + const RebindToUnsigned du; - const auto vi = BitCast(di, v); - const auto lowest_bit = BitCast(du, And(vi, Neg(vi))); + const auto flt_b = ConvertTo(df, b); + auto flt_recip_b = ApproximateReciprocal(flt_b); + if (kOrigLaneSize > 1) { + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); + } - constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; - const auto f32_biased_exp = detail::UIntToF32BiasedExp(lowest_bit); - const auto tz_count = Sub(f32_biased_exp, Set(du, TU{127})); + auto q0 = ConvertInRangeTo(d, Mul(ConvertTo(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); - return BitCast(d, - Min(detail::F32ExpLzcntMinMaxBitCast(tz_count), - detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); -} -#endif // HWY_NATIVE_LEADING_ZERO_COUNT + auto r1 = r0; -// ------------------------------ AESRound + // Need to negate r1[i] if a[i] < 0 is true + if (IsSigned>()) { + r1 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r1); + } -// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. -#if HWY_TARGET != HWY_SCALAR || HWY_IDE + // r1[i] is now equal to (a[i] < 0) ? (-r0[i]) : r0[i] -// Define for white-box testing, even if native instructions are available. -namespace detail { + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } -// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with -// Vector Permute Instructions" and the accompanying assembly language -// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: -// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . -// -// A brute-force 256 byte table lookup can also be made constant-time, and -// possibly competitive on NEON, but this is more performance-portable -// especially for x86 and large vectors. + // If (r1[i] < 0 || r1[i] >= abs_b[i]) is true, then set q1[i] to -1. + // Otherwise, set q1[i] to 0. -template // u8 -HWY_INLINE V SubBytesMulInverseAndAffineLookup(V state, V affine_tblL, - V affine_tblU) { - const DFromV du; - const auto mask = Set(du, uint8_t{0xF}); + // (r1[i] < 0 || r1[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r1[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r1[i] < 0 is true. + auto q1 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r1), abs_b))); - // Change polynomial basis to GF(2^4) - { - alignas(16) static constexpr uint8_t basisL[16] = { - 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, - 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA}; - alignas(16) static constexpr uint8_t basisU[16] = { - 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, - 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD}; - const auto sL = And(state, mask); - const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero - const auto gf4L = TableLookupBytes(LoadDup128(du, basisL), sL); - const auto gf4U = TableLookupBytes(LoadDup128(du, basisU), sU); - state = Xor(gf4L, gf4U); + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? 
-1 : 0 + + // Need to negate q1[i] if r0[i] and b[i] do not have the same sign + auto q1_negate_mask = r0; + if (IsSigned>()) { + q1_negate_mask = Xor(q1_negate_mask, BitCast(di, b)); } + q1 = IfNegativeThenElse(q1_negate_mask, Neg(q1), q1); - // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and - // cause TableLookupBytesOr0 to return 0. - alignas(16) static constexpr uint8_t kZetaInv[16] = { - 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3}; - alignas(16) static constexpr uint8_t kInv[16] = { - 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4}; - const auto tbl = LoadDup128(du, kInv); - const auto sL = And(state, mask); // L=low nibble, U=upper - const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero - const auto sX = Xor(sU, sL); - const auto invL = TableLookupBytes(LoadDup128(du, kZetaInv), sL); - const auto invU = TableLookupBytes(tbl, sU); - const auto invX = TableLookupBytes(tbl, sX); - const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); - const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? + // (((r0[i] ^ b[i]) < 0) ? 1 : -1) - const auto affL = TableLookupBytesOr0(affine_tblL, outL); - const auto affU = TableLookupBytesOr0(affine_tblU, outU); - return Xor(affL, affU); + // Need to subtract q1[i] from q0[i] to get the final result + return Sub(q0, BitCast(d, q1)); +#else + // On targets other than Armv7 NEON, use F16 or F32 division as most targets + // other than Armv7 NEON have native F32 divide instructions + return ConvertInRangeTo(d, Div(ConvertTo(df, a), ConvertTo(df, b))); +#endif } -template // u8 -HWY_INLINE V SubBytes(V state) { - const DFromV du; - // Linear skew (cannot bake 0x63 bias into the table because out* indices - // may have the infinity flag set). 
- alignas(16) static constexpr uint8_t kAffineL[16] = { - 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, - 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15}; - alignas(16) static constexpr uint8_t kAffineU[16] = { - 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, - 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E}; - return Xor(SubBytesMulInverseAndAffineLookup(state, LoadDup128(du, kAffineL), - LoadDup128(du, kAffineU)), - Set(du, uint8_t{0x63})); -} +template , kOrigLaneSize), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + // If kOrigLaneSize == sizeof(T) is true, at least two reciprocal + // multiplication steps are needed as the mantissa of MakeFloat has fewer + // than kOrigLaneSize*8 + 1 bits -template // u8 -HWY_INLINE V InvSubBytes(V state) { - const DFromV du; - alignas(16) static constexpr uint8_t kGF2P4InvToGF2P8InvL[16]{ - 0x00, 0x40, 0xF9, 0x7E, 0x53, 0xEA, 0x87, 0x13, - 0x2D, 0x3E, 0x94, 0xD4, 0xB9, 0x6D, 0xAA, 0xC7}; - alignas(16) static constexpr uint8_t kGF2P4InvToGF2P8InvU[16]{ - 0x00, 0x1D, 0x44, 0x93, 0x0F, 0x56, 0xD7, 0x12, - 0x9C, 0x8E, 0xC5, 0xD8, 0x59, 0x81, 0x4B, 0xCA}; + using T = TFromV; - // Apply the inverse affine transformation - const auto b = Xor(Xor3(Or(ShiftLeft<1>(state), ShiftRight<7>(state)), - Or(ShiftLeft<3>(state), ShiftRight<5>(state)), - Or(ShiftLeft<6>(state), ShiftRight<2>(state))), - Set(du, uint8_t{0x05})); +#if HWY_HAVE_FLOAT64 + using TF = MakeFloat; +#else + using TF = float; +#endif - // The GF(2^8) multiplicative inverse is computed as follows: - // - Changing the polynomial basis to GF(2^4) - // - Computing the GF(2^4) multiplicative inverse - // - Converting the GF(2^4) multiplicative inverse to the GF(2^8) - // multiplicative inverse through table lookups using the - // kGF2P4InvToGF2P8InvL and kGF2P4InvToGF2P8InvU tables - return SubBytesMulInverseAndAffineLookup( - b, LoadDup128(du, kGF2P4InvToGF2P8InvL), - LoadDup128(du, kGF2P4InvToGF2P8InvU)); -} + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + const Rebind df; + + if (!IsSigned()) { + // If T is unsigned, set a[i] to (a[i] >= b[i] ? 1 : 0) and set b[i] to 1 if + // b[i] > LimitsMax>() is true + + const auto one = Set(di, MakeSigned{1}); + a = BitCast( + d, IfNegativeThenElse(BitCast(di, b), + IfThenElseZero(RebindMask(di, Ge(a, b)), one), + BitCast(di, a))); + b = BitCast(d, IfNegativeThenElse(BitCast(di, b), one, BitCast(di, b))); + } -} // namespace detail + // LimitsMin() <= b[i] <= LimitsMax>() is now true -#endif // HWY_TARGET != HWY_SCALAR + const auto flt_b = IntDivConvIntToFloat(df, b); -// "Include guard": skip if native AES instructions are available. -#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_AES -#undef HWY_NATIVE_AES +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + auto flt_recip_b = ApproximateReciprocal(flt_b); + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); #else -#define HWY_NATIVE_AES + const auto flt_recip_b = Div(Set(df, TF(1.0)), flt_b); #endif -// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) -#if HWY_TARGET != HWY_SCALAR + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt returns incorrect results in any lanes where b[i] == 0 + // as the result of IntDivUsingFloatDiv(a, b) is implementation-defined in any + // lanes where b[i] == 0. 
+ + // If ScalarAbs(b[i]) == 1 is true, then it is possible for + // a[i] * flt_recip_b[i] to be rounded up to a value that is outside of the + // range of T. If a[i] * flt_recip_b[i] is outside of the range of T, + // IntDivConvFloatToInt will convert any values that are out of the range of T + // by either saturation, truncation, or wrapping around to LimitsMin(). + + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt wraps around if ScalarAbs(b[i]) == 1 as r0 will have + // the correct sign if ScalarAbs(b[i]) == 1, even in the cases where the + // conversion of a[i] * flt_recip_b[i] to T using IntDivConvFloatToInt is + // truncated or wraps around. + + // If ScalarAbs(b[i]) >= 2 is true, a[i] * flt_recip_b[i] will be within the + // range of T, even in the cases where the conversion of a[i] to TF is + // rounded up or the result of multiplying a[i] by flt_recip_b[i] is rounded + // up. + + // ScalarAbs(r0[i]) will also always be less than (LimitsMax() / 2) if + // b[i] != 0, even in the cases where the conversion of a[i] * flt_recip_b[i] + // to T using IntDivConvFloatToInt is truncated or is wrapped around. + + auto q0 = + IntDivConvFloatToInt(d, Mul(IntDivConvIntToFloat(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); + + // If b[i] != 0 is true, r0[i] * flt_recip_b[i] is always within the range of + // T, even in the cases where the conversion of r0[i] to TF is rounded up or + // the multiplication of r0[i] by flt_recip_b[i] is rounded up. + + auto q1 = + IntDivConvFloatToInt(di, Mul(IntDivConvIntToFloat(df, r0), flt_recip_b)); + const auto r1 = hwy::HWY_NAMESPACE::NegMulAdd(q1, BitCast(di, b), r0); + + auto r3 = r1; + +#if !HWY_HAVE_FLOAT64 + // Need two additional reciprocal multiplication steps for I64/U64 vectors if + // HWY_HAVE_FLOAT64 is 0 + if (sizeof(T) == 8) { + const auto q2 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r1), flt_recip_b)); + const auto r2 = hwy::HWY_NAMESPACE::NegMulAdd(q2, BitCast(di, b), r1); + + const auto q3 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r2), flt_recip_b)); + r3 = hwy::HWY_NAMESPACE::NegMulAdd(q3, BitCast(di, b), r2); + + q0 = Add(q0, BitCast(d, q2)); + q1 = Add(q1, q3); + } +#endif // !HWY_HAVE_FLOAT64 -namespace detail { + auto r4 = r3; -template // u8 -HWY_INLINE V ShiftRows(const V state) { - const DFromV du; - alignas(16) static constexpr uint8_t kShiftRow[16] = { - 0, 5, 10, 15, // transposed: state is column major - 4, 9, 14, 3, // - 8, 13, 2, 7, // - 12, 1, 6, 11}; - const auto shift_row = LoadDup128(du, kShiftRow); - return TableLookupBytes(state, shift_row); -} + // Need to negate r4[i] if a[i] < 0 is true + if (IsSigned>()) { + r4 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r4); + } -template // u8 -HWY_INLINE V InvShiftRows(const V state) { - const DFromV du; - alignas(16) static constexpr uint8_t kShiftRow[16] = { - 0, 13, 10, 7, // transposed: state is column major - 4, 1, 14, 11, // - 8, 5, 2, 15, // - 12, 9, 6, 3}; - const auto shift_row = LoadDup128(du, kShiftRow); - return TableLookupBytes(state, shift_row); -} + // r4[i] is now equal to (a[i] < 0) ? (-r3[i]) : r3[i] -template // u8 -HWY_INLINE V GF2P8Mod11BMulBy2(V v) { - const DFromV du; - const RebindToSigned di; // can only do signed comparisons - const auto msb = Lt(BitCast(di, v), Zero(di)); - const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); - return Xor(Add(v, v), overflow); // = v*2 in GF(2^8). 
-} + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } -template // u8 -HWY_INLINE V MixColumns(const V state) { - const DFromV du; - // For each column, the rows are the sum of GF(2^8) matrix multiplication by: - // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. - // 1 2 3 1 // d are on diagonal, no permutation needed. - // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. - // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). - alignas(16) static constexpr uint8_t k2301[16] = { - 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}; - alignas(16) static constexpr uint8_t k1230[16] = { - 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12}; - const auto d = GF2P8Mod11BMulBy2(state); // = state*2 in GF(2^8). - const auto s2301 = TableLookupBytes(state, LoadDup128(du, k2301)); - const auto d_s2301 = Xor(d, s2301); - const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} - const auto t1230_s3012 = TableLookupBytes(t_s2301, LoadDup128(du, k1230)); - return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms -} + // If (r4[i] < 0 || r4[i] >= abs_b[i]) is true, then set q4[i] to -1. + // Otherwise, set r4[i] to 0. -template // u8 -HWY_INLINE V InvMixColumns(const V state) { - const DFromV du; - // For each column, the rows are the sum of GF(2^8) matrix multiplication by: - // 14 11 13 9 - // 9 14 11 13 - // 13 9 14 11 - // 11 13 9 14 - alignas(16) static constexpr uint8_t k2301[16] = { - 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}; - alignas(16) static constexpr uint8_t k1230[16] = { - 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12}; - const auto v1230 = LoadDup128(du, k1230); + // (r4[i] < 0 || r4[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r4[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r4[i] < 0 is true. + auto q4 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r4), abs_b))); - const auto sx2 = GF2P8Mod11BMulBy2(state); /* = state*2 in GF(2^8) */ - const auto sx4 = GF2P8Mod11BMulBy2(sx2); /* = state*4 in GF(2^8) */ - const auto sx8 = GF2P8Mod11BMulBy2(sx4); /* = state*8 in GF(2^8) */ - const auto sx9 = Xor(sx8, state); /* = state*9 in GF(2^8) */ - const auto sx11 = Xor(sx9, sx2); /* = state*11 in GF(2^8) */ - const auto sx13 = Xor(sx9, sx4); /* = state*13 in GF(2^8) */ - const auto sx14 = Xor3(sx8, sx4, sx2); /* = state*14 in GF(2^8) */ + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? -1 : 0 - const auto sx13_0123_sx9_1230 = Xor(sx13, TableLookupBytes(sx9, v1230)); - const auto sx14_0123_sx11_1230 = Xor(sx14, TableLookupBytes(sx11, v1230)); - const auto sx13_2301_sx9_3012 = - TableLookupBytes(sx13_0123_sx9_1230, LoadDup128(du, k2301)); - return Xor(sx14_0123_sx11_1230, sx13_2301_sx9_3012); -} + // Need to negate q4[i] if r3[i] and b[i] do not have the same sign + auto q4_negate_mask = r3; + if (IsSigned>()) { + q4_negate_mask = Xor(q4_negate_mask, BitCast(di, b)); + } + q4 = IfNegativeThenElse(q4_negate_mask, Neg(q4), q4); -} // namespace detail + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? + // (((r3[i] ^ b[i]) < 0) ? 1 : -1) -template // u8 -HWY_API V AESRound(V state, const V round_key) { - // Intel docs swap the first two steps, but it does not matter because - // ShiftRows is a permutation and SubBytes is independent of lane index. 
- state = detail::SubBytes(state); - state = detail::ShiftRows(state); - state = detail::MixColumns(state); - state = Xor(state, round_key); // AddRoundKey - return state; + // The final result is equal to q0[i] + q1[i] - q4[i] + return Sub(Add(q0, BitCast(d, q1)), BitCast(d, q4)); } -template // u8 -HWY_API V AESLastRound(V state, const V round_key) { - // LIke AESRound, but without MixColumns. - state = detail::SubBytes(state); - state = detail::ShiftRows(state); - state = Xor(state, round_key); // AddRoundKey - return state; +template ) == 1) ? 4 : 2))> +HWY_INLINE V IntDiv(V a, V b) { + using T = TFromV; + + // If HWY_HAVE_FLOAT16 is 0, need to promote I8 to I32 and U8 to U32 + using TW = MakeWide< + If<(!HWY_HAVE_FLOAT16 && sizeof(TFromV) == 1), MakeWide, T>>; + + const DFromV d; + const Rebind dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned if + // kOrigLaneSize < sizeof(T) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(T)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote to TW and demote to T + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast( + d, DemoteTo(d_demote_to, IntDivUsingFloatDiv( + PromoteTo(dw_i, a), PromoteTo(dw_i, b)))); } -template -HWY_API V AESInvMixColumns(V state) { - return detail::InvMixColumns(state); +template +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned> if + // kOrigLaneSize < sizeof(TFromV) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(TFromV)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote to MakeWide> and demote to TFromV + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast(d, OrderedDemote2To( + d_demote_to, + IntDivUsingFloatDiv( + PromoteLowerTo(dw_i, a), PromoteLowerTo(dw_i, b)), + IntDivUsingFloatDiv( + PromoteUpperTo(dw_i, a), PromoteUpperTo(dw_i, b)))); } -template // u8 -HWY_API V AESRoundInv(V state, const V round_key) { - state = detail::InvSubBytes(state); - state = detail::InvShiftRows(state); - state = detail::InvMixColumns(state); - state = Xor(state, round_key); // AddRoundKey - return state; +#if !HWY_HAVE_FLOAT16 +template ), + HWY_IF_V_SIZE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return DemoteTo(d, + BitCast(dw_i, IntDiv<1>(PromoteTo(dw, a), PromoteTo(dw, b)))); } +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; -template // u8 -HWY_API V AESLastRoundInv(V state, const V round_key) { - // Like AESRoundInv, but without InvMixColumns. 
- state = detail::InvSubBytes(state); - state = detail::InvShiftRows(state); - state = Xor(state, round_key); // AddRoundKey - return state; +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return OrderedDemote2To( + d, BitCast(dw_i, IntDiv<1>(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b))), + BitCast(dw_i, IntDiv<1>(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)))); } +#endif // !HWY_HAVE_FLOAT16 -template )> -HWY_API V AESKeyGenAssist(V v) { - alignas(16) static constexpr uint8_t kRconXorMask[16] = { - 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0}; - alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { - 4, 5, 6, 7, 5, 6, 7, 4, 12, 13, 14, 15, 13, 14, 15, 12}; - const DFromV d; - const auto sub_word_result = detail::SubBytes(v); - const auto rot_word_result = - TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); - return Xor(rot_word_result, LoadDup128(d, kRconXorMask)); +template +HWY_INLINE V IntDiv(V a, V b) { + return IntDivUsingFloatDiv(a, b); } -// Constant-time implementation inspired by -// https://www.bearssl.org/constanttime.html, but about half the cost because we -// use 64x64 multiplies and 128-bit XORs. -template -HWY_API V CLMulLower(V a, V b) { - const DFromV d; - static_assert(IsSame, uint64_t>(), "V must be u64"); - const auto k1 = Set(d, 0x1111111111111111ULL); - const auto k2 = Set(d, 0x2222222222222222ULL); - const auto k4 = Set(d, 0x4444444444444444ULL); - const auto k8 = Set(d, 0x8888888888888888ULL); - const auto a0 = And(a, k1); - const auto a1 = And(a, k2); - const auto a2 = And(a, k4); - const auto a3 = And(a, k8); - const auto b0 = And(b, k1); - const auto b1 = And(b, k2); - const auto b2 = And(b, k4); - const auto b3 = And(b, k8); +#if HWY_HAVE_FLOAT64 +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind df64; - auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); - auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); - auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); - auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); - m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); - m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); - m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); - m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); - return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. 
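The comment above rests on double precision having a 53-bit mantissa: the 32-bit inputs and every possible truncated quotient are exactly representable, and a correctly rounded F64 quotient can never truncate to the wrong integer when b[i] != 0. A standalone scalar check of that reasoning (illustrative, not the vectorized code):

```cpp
#include <cassert>
#include <cstdint>

// For 32-bit a and b with b != 0, one double division plus truncation is exact.
static int32_t DivViaDouble(int32_t a, int32_t b) {
  return static_cast<int32_t>(static_cast<double>(a) / static_cast<double>(b));
}

int main() {
  assert(DivViaDouble(7, -2) == 7 / -2);                 // truncates toward zero
  assert(DivViaDouble(-2147483647 - 1, 2) == -1073741824);
  assert(DivViaDouble(2147483647, 3) == 2147483647 / 3);
  return 0;
}
```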
+ + return DemoteInRangeTo(d, Div(PromoteTo(df64, a), PromoteTo(df64, b))); } +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Half dh; + const Repartition df64; -template -HWY_API V CLMulUpper(V a, V b) { - const DFromV d; - static_assert(IsSame, uint64_t>(), "V must be u64"); - const auto k1 = Set(d, 0x1111111111111111ULL); - const auto k2 = Set(d, 0x2222222222222222ULL); - const auto k4 = Set(d, 0x4444444444444444ULL); - const auto k8 = Set(d, 0x8888888888888888ULL); - const auto a0 = And(a, k1); - const auto a1 = And(a, k2); - const auto a2 = And(a, k4); - const auto a3 = And(a, k8); - const auto b0 = And(b, k1); - const auto b1 = And(b, k2); - const auto b2 = And(b, k4); - const auto b3 = And(b, k8); + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. - auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); - auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); - auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); - auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); - m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); - m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); - m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); - m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); - return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); + const VFromD div1 = + Div(PromoteUpperTo(df64, a), PromoteUpperTo(df64, b)); + const VFromD div0 = + Div(PromoteLowerTo(df64, a), PromoteLowerTo(df64, b)); + return Combine(d, DemoteInRangeTo(dh, div1), DemoteInRangeTo(dh, div0)); } +#endif // HWY_HAVE_FLOAT64 -#endif // HWY_NATIVE_AES -#endif // HWY_TARGET != HWY_SCALAR +template +HWY_INLINE V IntMod(V a, V b) { + return hwy::HWY_NAMESPACE::NegMulAdd(IntDiv(a, b), b, a); +} -// ------------------------------ PopulationCount +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + return DemoteTo(d, IntMod(PromoteTo(dw, a), PromoteTo(dw, b))); +} -// "Include guard": skip if native POPCNT-related instructions are available. -#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_POPCNT -#undef HWY_NATIVE_POPCNT -#else -#define HWY_NATIVE_POPCNT -#endif +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To( + d, IntMod(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b)), + IntMod(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b))); +} +#endif // HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || HWY_TARGET == + // HWY_WASM_EMU256 -// This overload requires vectors to be at least 16 bytes, which is the case -// for LMUL >= 2. -#undef HWY_IF_POPCNT -#if HWY_TARGET == HWY_RVV -#define HWY_IF_POPCNT(D) \ - hwy::EnableIf= 1 && D().MaxLanes() >= 16>* = nullptr -#else -// Other targets only have these two overloads which are mutually exclusive, so -// no further conditions are required. 
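The IntMod helper above is the remainder identity r = a - (a / b) * b evaluated with the same truncating division; NegMulAdd(q, b, a) computes exactly a - q * b per lane. A minimal scalar illustration (the function name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Remainder recovered from a truncating quotient, as IntMod does lane-wise.
static int32_t ModFromDiv(int32_t a, int32_t b) {
  const int32_t q = a / b;  // scalar stand-in for the vector IntDiv above
  return a - q * b;
}

int main() {
  assert(ModFromDiv(7, 3) == 1);
  assert(ModFromDiv(-7, 3) == -1);  // sign follows the dividend, as in C++
  assert(ModFromDiv(7, -3) == 1);
  return 0;
}
```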
-#define HWY_IF_POPCNT(D) void* = nullptr -#endif // HWY_TARGET == HWY_RVV +} // namespace detail -template , HWY_IF_U8_D(D), - HWY_IF_V_SIZE_GT_D(D, 8), HWY_IF_POPCNT(D)> -HWY_API V PopulationCount(V v) { - const D d; - HWY_ALIGN constexpr uint8_t kLookup[16] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - }; - const auto lo = And(v, Set(d, uint8_t{0xF})); - const auto hi = ShiftRight<4>(v); - const auto lookup = LoadDup128(d, kLookup); - return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +#if HWY_TARGET == HWY_SCALAR + +template +HWY_API Vec1 operator/(Vec1 a, Vec1 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec1 operator%(Vec1 a, Vec1 b) { + return detail::IntMod(a, b); } -// RVV has a specialization that avoids the Set(). -#if HWY_TARGET != HWY_RVV -// Slower fallback for capped vectors. -template , HWY_IF_U8_D(D), - HWY_IF_V_SIZE_LE_D(D, 8)> -HWY_API V PopulationCount(V v) { - const D d; - // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 - const V k33 = Set(d, uint8_t{0x33}); - v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); - v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); - return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); +#else // HWY_TARGET != HWY_SCALAR + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + return detail::IntDiv(a, b); } -#endif // HWY_TARGET != HWY_RVV -template , HWY_IF_U16_D(D)> -HWY_API V PopulationCount(V v) { - const D d; - const Repartition d8; - const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); - return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + return detail::IntMod(a, b); } -template , HWY_IF_U32_D(D)> -HWY_API V PopulationCount(V v) { - const D d; - Repartition d16; - auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); - return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +#if HWY_CAP_GE256 +template +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return detail::IntDiv(a, b); } +template +HWY_API Vec256 operator%(Vec256 a, Vec256 b) { + return detail::IntMod(a, b); +} +#endif -#if HWY_HAVE_INTEGER64 -template , HWY_IF_U64_D(D)> -HWY_API V PopulationCount(V v) { - const D d; - Repartition d32; - auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); - return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +#if HWY_CAP_GE512 +template +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec512 operator%(Vec512 a, Vec512 b) { + return detail::IntMod(a, b); } #endif -#endif // HWY_NATIVE_POPCNT +#endif // HWY_TARGET == HWY_SCALAR -// ------------------------------ 8-bit multiplication +#endif // HWY_NATIVE_INT_DIV -// "Include guard": skip if native 8-bit mul instructions are available. 
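With the operator/ and operator% overloads added above, integer vectors gain lane-wise division and remainder on all targets. A minimal static-dispatch usage sketch follows; hn::ScalableTag, Iota, Set, Div, Mod and GetLane are existing Highway API, while the program scaffolding and values are assumptions for illustration (built for the statically chosen baseline target).

```cpp
#include <cstdio>
#include <cstdint>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::ScalableTag<int32_t> d;
  const auto a = hn::Iota(d, 10);  // 10, 11, 12, ...
  const auto b = hn::Set(d, 3);
  const auto q = hn::Div(a, b);    // lane-wise truncating a / b
  const auto r = hn::Mod(a, b);    // lane-wise a % b
  printf("lane0: %d %d\n", static_cast<int>(hn::GetLane(q)),
         static_cast<int>(hn::GetLane(r)));
  return 0;
}
```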
-#if (defined(HWY_NATIVE_MUL_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE -#ifdef HWY_NATIVE_MUL_8 -#undef HWY_NATIVE_MUL_8 +// ------------------------------ AverageRound + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 #else -#define HWY_NATIVE_MUL_8 +#define HWY_NATIVE_AVERAGE_ROUND_UI32 #endif -// 8 bit and fits in wider reg: promote -template -HWY_API V operator*(const V a, const V b) { +template )> +HWY_API V AverageRound(V a, V b) { + using T = TFromV; const DFromV d; - const Rebind>, decltype(d)> dw; - const RebindToUnsigned du; // TruncateTo result - const RebindToUnsigned dwu; // TruncateTo input - const VFromD mul = PromoteTo(dw, a) * PromoteTo(dw, b); - // TruncateTo is cheaper than ConcatEven. - return BitCast(d, TruncateTo(du, BitCast(dwu, mul))); + return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)), + And(Or(a, b), Set(d, T{1}))); } -// 8 bit full reg: promote halves -template -HWY_API V operator*(const V a, const V b) { +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V AverageRound(V a, V b) { + using T = TFromV; const DFromV d; - const Half dh; - const Twice> dw; - const VFromD a0 = PromoteTo(dw, LowerHalf(dh, a)); - const VFromD a1 = PromoteTo(dw, UpperHalf(dh, a)); - const VFromD b0 = PromoteTo(dw, LowerHalf(dh, b)); - const VFromD b1 = PromoteTo(dw, UpperHalf(dh, b)); - const VFromD m0 = a0 * b0; - const VFromD m1 = a1 * b1; - return ConcatEven(d, BitCast(d, m1), BitCast(d, m0)); + return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)), + And(Or(a, b), Set(d, T{1}))); } +#endif -#endif // HWY_NATIVE_MUL_8 +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 -// ------------------------------ 64-bit multiplication +// ------------------------------ RoundingShiftRight (AverageRound) -// "Include guard": skip if native 64-bit mul instructions are available. 
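The generic AverageRound added above uses the overflow-free identity (a >> 1) + (b >> 1) + ((a | b) & 1) == (a + b + 1) / 2: the last term restores the carry that the two halved operands drop. A scalar check of that identity for unsigned values (illustrative only; the vector code applies the same formula per lane):

```cpp
#include <cassert>
#include <cstdint>

// Rounding-up average without widening and without overflow.
static uint32_t AverageRoundUp(uint32_t a, uint32_t b) {
  return (a >> 1) + (b >> 1) + ((a | b) & 1u);
}

int main() {
  assert(AverageRoundUp(3, 4) == 4);  // 3.5 rounds up
  assert(AverageRoundUp(0xFFFFFFFFu, 0xFFFFFFFFu) == 0xFFFFFFFFu);  // no overflow
  for (uint32_t a = 0; a < 64; ++a) {
    for (uint32_t b = 0; b < 64; ++b) {
      const uint32_t ref =
          static_cast<uint32_t>((static_cast<uint64_t>(a) + b + 1) / 2);
      assert(AverageRoundUp(a, b) == ref);
    }
  }
  return 0;
}
```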
-#if (defined(HWY_NATIVE_MUL_64) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE -#ifdef HWY_NATIVE_MUL_64 -#undef HWY_NATIVE_MUL_64 +#if (defined(HWY_NATIVE_ROUNDING_SHR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR #else -#define HWY_NATIVE_MUL_64 +#define HWY_NATIVE_ROUNDING_SHR #endif -// Single-lane i64 or u64 -template -HWY_API V operator*(V x, V y) { +template +HWY_API V RoundingShiftRight(V v) { + const DFromV d; + using T = TFromD; + + static_assert( + 0 <= kShiftAmt && kShiftAmt <= static_cast(sizeof(T) * 8 - 1), + "kShiftAmt is out of range"); + + constexpr int kScaleDownShrAmt = HWY_MAX(kShiftAmt - 1, 0); + + auto scaled_down_v = v; + HWY_IF_CONSTEXPR(kScaleDownShrAmt > 0) { + scaled_down_v = ShiftRight(v); + } + + HWY_IF_CONSTEXPR(kShiftAmt == 0) { return scaled_down_v; } + + return AverageRound(scaled_down_v, Zero(d)); +} + +template +HWY_API V RoundingShiftRightSame(V v, int shift_amt) { + const DFromV d; + using T = TFromD; + + const int shift_amt_is_zero_mask = -static_cast(shift_amt == 0); + + const auto scaled_down_v = ShiftRightSame( + v, static_cast(static_cast(shift_amt) + + static_cast(~shift_amt_is_zero_mask))); + + return AverageRound( + scaled_down_v, + And(scaled_down_v, Set(d, static_cast(shift_amt_is_zero_mask)))); +} + +template +HWY_API V RoundingShr(V v, V amt) { const DFromV d; + const RebindToUnsigned du; using T = TFromD; using TU = MakeUnsigned; - const TU xu = static_cast(GetLane(x)); - const TU yu = static_cast(GetLane(y)); - return Set(d, static_cast(xu * yu)); + + const auto unsigned_amt = BitCast(du, amt); + const auto scale_down_shr_amt = + BitCast(d, SaturatedSub(unsigned_amt, Set(du, TU{1}))); + + const auto scaled_down_v = Shr(v, scale_down_shr_amt); + return AverageRound(scaled_down_v, + IfThenElseZero(Eq(amt, Zero(d)), scaled_down_v)); } -template , HWY_IF_U64_D(D64), - HWY_IF_V_SIZE_GT_D(D64, 8)> -HWY_API V operator*(V x, V y) { - RepartitionToNarrow d32; - auto x32 = BitCast(d32, x); - auto y32 = BitCast(d32, y); - auto lolo = BitCast(d32, MulEven(x32, y32)); - auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); - auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); - auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); - return BitCast(D64{}, lolo + hi); +#endif // HWY_NATIVE_ROUNDING_SHR + +// ------------------------------ MulEvenAdd (PromoteEvenTo) + +// SVE with bf16 and NEON with bf16 override this. +#if (defined(HWY_NATIVE_MUL_EVEN_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template >> +HWY_API VFromD MulEvenAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), c); } -template , HWY_IF_I64_D(DI64), - HWY_IF_V_SIZE_GT_D(DI64, 8)> -HWY_API V operator*(V x, V y) { - RebindToUnsigned du64; - return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); + +template >> +HWY_API VFromD MulOddAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteOddTo(df, a), PromoteOddTo(df, b), c); } -#endif // HWY_NATIVE_MUL_64 +#endif // HWY_NATIVE_MUL_EVEN_BF16 -// ------------------------------ MulAdd / NegMulAdd +// ------------------------------ ReorderWidenMulAccumulate (MulEvenAdd) -// "Include guard": skip if native int MulAdd instructions are available. 
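RoundingShiftRight above shifts by kShiftAmt - 1 and then applies AverageRound with zero, which for unsigned lanes and kShiftAmt >= 1 matches the usual add-half-then-shift formulation (v + (1 << (k - 1))) >> k, but without risking overflow in the addition (kShiftAmt == 0 is special-cased to return v). A scalar sanity check of that equivalence, unsigned only and with an assumed helper name:

```cpp
#include <cassert>
#include <cstdint>

// Rounding shift right, structured as in the generic code above:
// shift by (k - 1), then "average with zero", i.e. halve and round up.
static uint32_t RoundingShr32(uint32_t v, int k) {  // 1 <= k <= 31
  const uint32_t s = v >> (k - 1);
  return (s >> 1) + (s & 1u);  // AverageRound(s, 0)
}

int main() {
  for (uint64_t v : {0ull, 1ull, 7ull, 8ull, 12ull, 0xFFFFFFFFull}) {
    for (int k = 1; k <= 8; ++k) {
      // Reference: add half then shift, done in 64 bits to avoid overflow.
      const uint32_t ref = static_cast<uint32_t>((v + (1ull << (k - 1))) >> k);
      assert(RoundingShr32(static_cast<uint32_t>(v), k) == ref);
    }
  }
  return 0;
}
```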
-#if (defined(HWY_NATIVE_INT_FMA) == defined(HWY_TARGET_TOGGLE)) -#ifdef HWY_NATIVE_INT_FMA -#undef HWY_NATIVE_INT_FMA +// AVX3_SPR/ZEN4, and NEON with bf16 but not(!) SVE override this. +#if (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 #else -#define HWY_NATIVE_INT_FMA +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 #endif -template -HWY_API V MulAdd(V mul, V x, V add) { - return Add(Mul(mul, x), add); +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF df, VBF a, VBF b, + VFromD sum0, + VFromD& sum1) { + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. + sum1 = MulOddAdd(df, a, b, sum1); + return MulEvenAdd(df, a, b, sum0); } -template -HWY_API V NegMulAdd(V mul, V x, V add) { - return Sub(add, Mul(mul, x)); +#endif // HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 + +// ------------------------------ WidenMulAccumulate + +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#endif + +template), + class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); } -#endif // HWY_NATIVE_INT_FMA +#endif // HWY_NATIVE_WIDEN_MUL_ACCUMULATE + +#if 0 +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif + +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); +} + +#endif // HWY_HAVE_FLOAT16 + +#endif // HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif // #if 0 // ------------------------------ SatWidenMulPairwiseAdd @@ -2819,17 +5113,77 @@ template SatWidenMulPairwiseAdd(DI16 di16, VU8 a, VI8 b) { const RebindToUnsigned du16; - const auto a0 = And(BitCast(di16, a), Set(di16, int16_t{0x00FF})); - const auto b0 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, b))); + const auto a0 = BitCast(di16, PromoteEvenTo(du16, a)); + const auto b0 = PromoteEvenTo(di16, b); - const auto a1 = BitCast(di16, ShiftRight<8>(BitCast(du16, a))); - const auto b1 = ShiftRight<8>(BitCast(di16, b)); + const auto a1 = BitCast(di16, PromoteOddTo(du16, a)); + const auto b1 = PromoteOddTo(di16, b); return SaturatedAdd(Mul(a0, b0), Mul(a1, b1)); } #endif +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if (defined(HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 di32, VFromD> a, + VFromD> b, VFromD sum) { + // WidenMulPairwiseAdd(di32, a, b) is okay here as + // a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as + // a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if + // a[0], b[0], a[1], and b[1] are all 
equal to -32768. + + const auto product = WidenMulPairwiseAdd(di32, a, b); + + const auto mul_overflow = + VecFromMask(di32, Eq(product, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product, mul_overflow)); +} + +#endif // HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if (defined(HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition dt_i16; + + const auto vt_a = ResizeBitCast(dt_i16, a); + const auto vt_b = ResizeBitCast(dt_i16, b); + + const auto dup_a = InterleaveWholeLower(dt_i16, vt_a, vt_a); + const auto dup_b = InterleaveWholeLower(dt_i16, vt_b, vt_b); + + return SatWidenMulPairwiseAccumulate(di32, dup_a, dup_b, sum); +} + +#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT + // ------------------------------ SumOfMulQuadAccumulate #if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \ @@ -2848,11 +5202,11 @@ HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, VFromD sum) { const Repartition di16; - const auto a0 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, a))); - const auto b0 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, b))); + const auto a0 = PromoteEvenTo(di16, a); + const auto b0 = PromoteEvenTo(di16, b); - const auto a1 = ShiftRight<8>(BitCast(di16, a)); - const auto b1 = ShiftRight<8>(BitCast(di16, b)); + const auto a1 = PromoteOddTo(di16, a); + const auto b1 = PromoteOddTo(di16, b); return Add(sum, Add(WidenMulPairwiseAdd(di32, a0, b0), WidenMulPairwiseAdd(di32, a1, b1))); @@ -2985,12 +5339,10 @@ HWY_API VFromD SumOfMulQuadAccumulate( const auto u32_even_prod = MulEven(a, b); const auto u32_odd_prod = MulOdd(a, b); - const auto lo32_mask = Set(du64, uint64_t{0xFFFFFFFFu}); - - const auto p0 = Add(And(BitCast(du64, u32_even_prod), lo32_mask), - And(BitCast(du64, u32_odd_prod), lo32_mask)); - const auto p1 = Add(ShiftRight<32>(BitCast(du64, u32_even_prod)), - ShiftRight<32>(BitCast(du64, u32_odd_prod))); + const auto p0 = Add(PromoteEvenTo(du64, u32_even_prod), + PromoteEvenTo(du64, u32_odd_prod)); + const auto p1 = + Add(PromoteOddTo(du64, u32_even_prod), PromoteOddTo(du64, u32_odd_prod)); return Add(sum, Add(p0, p1)); } @@ -3043,7 +5395,6 @@ HWY_API V ApproximateReciprocalSqrt(V v) { // ------------------------------ Compress* -// "Include guard": skip if native 8-bit compress instructions are available. #if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE)) #ifdef HWY_NATIVE_COMPRESS8 #undef HWY_NATIVE_COMPRESS8 @@ -3244,7 +5595,6 @@ HWY_API V CompressNot(V v, M mask) { // ------------------------------ Expand -// "Include guard": skip if native 8/16-bit Expand/LoadExpand are available. // Note that this generic implementation assumes <= 128 bit fixed vectors; // the SVE and RVV targets provide their own native implementations. #if (defined(HWY_NATIVE_EXPAND) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE @@ -3853,7 +6203,9 @@ HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { BitCast(du, InterleaveLower(du8x2, indices8, indices8)); // TableLookupBytesOr0 operates on bytes. To convert u16 lane indices to byte // indices, add 0 to even and 1 to odd byte lanes. 
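The overflow argument in the SatWidenMulPairwiseAccumulate comment above is narrow: for int16 inputs, a0*b0 + a1*b1 lies in [-2147418112, 2147483648], so the only pair-sum that does not fit in int32 is the single all -32768 case, and it is detectable because only that wrapped sum can equal INT32_MIN. A scalar illustration of those bounds (not the vector code):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Pairwise widened product-sum of int16 values, evaluated in int64 so the
// extreme cases can be inspected without overflow.
static int64_t PairSum(int16_t a0, int16_t b0, int16_t a1, int16_t b1) {
  return static_cast<int64_t>(a0) * b0 + static_cast<int64_t>(a1) * b1;
}

int main() {
  const int16_t kMin = std::numeric_limits<int16_t>::min();  // -32768
  const int16_t kMax = std::numeric_limits<int16_t>::max();  //  32767
  const int64_t kI32Min = std::numeric_limits<int32_t>::min();
  const int64_t kI32Max = std::numeric_limits<int32_t>::max();

  // Most negative pair-sum: strictly above INT32_MIN, so no legitimate
  // (non-overflowing) pair-sum can ever equal INT32_MIN.
  assert(PairSum(kMin, kMax, kMin, kMax) == -2147418112 &&
         -2147418112 > kI32Min);

  // The single case that exceeds INT32_MAX: all four inputs are -32768.
  assert(PairSum(kMin, kMin, kMin, kMin) == kI32Max + 1);

  // The next-largest pair-sum already fits in int32.
  assert(PairSum(kMin, kMin, kMin, static_cast<int16_t>(-32767)) ==
             2147450880 &&
         2147450880 <= kI32Max);
  return 0;
}
```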
- const Vec128 byte_indices = Add(indices16, Set(du, 0x0100)); + const Vec128 byte_indices = Add( + indices16, + Set(du, static_cast(HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001))); return BitCast(d, TableLookupBytesOr0(v, byte_indices)); } @@ -3911,9 +6263,7 @@ using IndicesFromD = decltype(IndicesFromVec(D(), Zero(RebindToUnsigned()))); // RVV/SVE have their own implementations of // TwoTablesLookupLanes(D d, VFromD a, VFromD b, IndicesFromD idx) -#if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SVE && \ - HWY_TARGET != HWY_SVE2 && HWY_TARGET != HWY_SVE_256 && \ - HWY_TARGET != HWY_SVE2_128 +#if HWY_TARGET != HWY_RVV && !HWY_TARGET_IS_SVE template HWY_API VFromD TwoTablesLookupLanes(D /*d*/, VFromD a, VFromD b, IndicesFromD idx) { @@ -3947,9 +6297,9 @@ HWY_API VFromD Reverse2(D d, VFromD v) { const Repartition du16; return BitCast(d, RotateRight<8>(BitCast(du16, v))); #else - alignas(16) static constexpr TFromD kShuffle[16] = { - 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14}; - return TableLookupBytes(v, LoadDup128(d, kShuffle)); + const VFromD shuffle = Dup128VecFromValues(d, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, + 11, 10, 13, 12, 15, 14); + return TableLookupBytes(v, shuffle); #endif } @@ -3959,10 +6309,10 @@ HWY_API VFromD Reverse4(D d, VFromD v) { const Repartition du16; return BitCast(d, Reverse2(du16, BitCast(du16, Reverse2(d, v)))); #else - alignas(16) static constexpr uint8_t kShuffle[16] = { - 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12}; const Repartition du8; - return TableLookupBytes(v, BitCast(d, LoadDup128(du8, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du8, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12); + return TableLookupBytes(v, BitCast(d, shuffle)); #endif } @@ -3972,10 +6322,10 @@ HWY_API VFromD Reverse8(D d, VFromD v) { const Repartition du32; return BitCast(d, Reverse2(du32, BitCast(du32, Reverse4(d, v)))); #else - alignas(16) static constexpr uint8_t kShuffle[16] = { - 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; const Repartition du8; - return TableLookupBytes(v, BitCast(d, LoadDup128(du8, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8); + return TableLookupBytes(v, BitCast(d, shuffle)); #endif } @@ -4103,7 +6453,7 @@ HWY_API V ReverseBits(V v) { #define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 #endif -#if HWY_TARGET != HWY_SCALAR +#if HWY_TARGET != HWY_SCALAR || HWY_IDE namespace detail { template @@ -4111,15 +6461,13 @@ HWY_INLINE Vec Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, const uint32_t x2, const uint32_t x1, const uint32_t x0) { - alignas(16) const uint32_t lanes[4] = {x0, x1, x2, x3}; - #if HWY_TARGET == HWY_RVV constexpr int kPow2 = d.Pow2(); constexpr int kLoadPow2 = HWY_MAX(kPow2, -1); const ScalableTag d_load; #else constexpr size_t kMaxBytes = d.MaxBytes(); -#if HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_NEON_WITHOUT_AES +#if HWY_TARGET_IS_NEON constexpr size_t kMinLanesToLoad = 2; #else constexpr size_t kMinLanesToLoad = 4; @@ -4128,8 +6476,7 @@ HWY_INLINE Vec Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, HWY_MAX(kMaxBytes / sizeof(uint32_t), kMinLanesToLoad); const CappedTag d_load; #endif - - return ResizeBitCast(d, LoadDup128(d_load, lanes)); + return ResizeBitCast(d, Dup128VecFromValues(d_load, x0, x1, x2, x3)); } } // namespace detail @@ -4137,7 +6484,7 @@ HWY_INLINE Vec Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, #endif // HWY_NATIVE_PER4LANEBLKSHUF_DUP32 -#if HWY_TARGET != HWY_SCALAR +#if HWY_TARGET != HWY_SCALAR || 
HWY_IDE namespace detail { template @@ -4189,8 +6536,7 @@ HWY_INLINE Vec TblLookupPer4LaneBlkU8IdxInBlk(D d, const uint32_t idx3, d, Set(du32, U8x4Per4LaneBlkIndices(idx3, idx2, idx1, idx0))); } -#if HWY_HAVE_SCALABLE || HWY_TARGET == HWY_SVE_256 || \ - HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_EMU128 +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_EMU128 #define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) void* = nullptr #else #define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) HWY_IF_T_SIZE_D(D, 8) @@ -4291,19 +6637,16 @@ HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, const uint16_t u16_idx1 = static_cast(idx1); const uint16_t u16_idx2 = static_cast(idx2); const uint16_t u16_idx3 = static_cast(idx3); - alignas(16) - const uint16_t indices[8] = {u16_idx0, u16_idx1, u16_idx2, u16_idx3, - u16_idx0, u16_idx1, u16_idx2, u16_idx3}; - -#if HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_NEON_WITHOUT_AES +#if HWY_TARGET_IS_NEON constexpr size_t kMinLanesToLoad = 4; #else constexpr size_t kMinLanesToLoad = 8; #endif constexpr size_t kNumToLoad = HWY_MAX(HWY_MAX_LANES_D(D), kMinLanesToLoad); const CappedTag d_load; - - return ResizeBitCast(d, LoadDup128(d_load, indices)); + return ResizeBitCast( + d, Dup128VecFromValues(d_load, u16_idx0, u16_idx1, u16_idx2, u16_idx3, + u16_idx0, u16_idx1, u16_idx2, u16_idx3)); } template @@ -4524,7 +6867,7 @@ HWY_API V Per4LaneBlockShuffle(V v) { return v; } -#if HWY_TARGET != HWY_SCALAR +#if HWY_TARGET != HWY_SCALAR || HWY_IDE template , 2)> HWY_API V Per4LaneBlockShuffle(V v) { @@ -4623,7 +6966,7 @@ HWY_API VFromD Slide1Down(D d, VFromD /*v*/) { return Zero(d); } -#if HWY_TARGET != HWY_SCALAR +#if HWY_TARGET != HWY_SCALAR || HWY_IDE template HWY_API VFromD Slide1Up(D d, VFromD v) { return ShiftLeftLanes<1>(d, v); @@ -4672,6 +7015,290 @@ HWY_API VFromD SlideDownBlocks(D d, VFromD v) { } #endif +// ------------------------------ Slide mask up/down +#if (defined(HWY_NATIVE_SLIDE_MASK) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK +#else +#define HWY_NATIVE_SLIDE_MASK +#endif + +template +HWY_API Mask SlideMask1Up(D d, Mask m) { + return MaskFromVec(Slide1Up(d, VecFromMask(d, m))); +} + +template +HWY_API Mask SlideMask1Down(D d, Mask m) { + return MaskFromVec(Slide1Down(d, VecFromMask(d, m))); +} + +template +HWY_API Mask SlideMaskUpLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideUpLanes(d, VecFromMask(d, m), amt)); +} + +template +HWY_API Mask SlideMaskDownLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideDownLanes(d, VecFromMask(d, m), amt)); +} + +#endif // HWY_NATIVE_SLIDE_MASK + +// ------------------------------ SumsOfAdjQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfAdjQuadAbsDiff(V8 a, V8 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + using D8 = DFromV; + const D8 d8; + const RebindToUnsigned du8; + const RepartitionToWide d16; + const RepartitionToWide du16; + + // Ensure that a is resized to a vector that has at least + // HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the interleave and + // CombineShiftRightBytes operations below. 
+#if HWY_TARGET == HWY_RVV + // On RVV targets, need to ensure that d8_interleave.Pow2() >= 0 is true + // to ensure that Lanes(d8_interleave) >= 16 is true. + + // Lanes(d8_interleave) >= Lanes(d8) is guaranteed to be true on RVV + // targets as d8_interleave.Pow2() >= d8.Pow2() is true. + constexpr int kInterleavePow2 = HWY_MAX(d8.Pow2(), 0); + const ScalableTag, kInterleavePow2> d8_interleave; +#elif HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On SVE targets, Lanes(d8_interleave) >= 16 and + // Lanes(d8_interleave) >= Lanes(d8) are both already true as d8 is a SIMD + // tag for a full u8/i8 vector on SVE. + const D8 d8_interleave; +#else + // On targets that use non-scalable vector types, Lanes(d8_interleave) is + // equal to HWY_MAX(Lanes(d8), size_t{8} << kAOffset). + constexpr size_t kInterleaveLanes = + HWY_MAX(HWY_MAX_LANES_D(D8), size_t{8} << kAOffset); + const FixedTag, kInterleaveLanes> d8_interleave; +#endif + + // The ResizeBitCast operation below will resize a to a vector that has + // at least HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the + // InterleaveLower, InterleaveUpper, and CombineShiftRightBytes operations + // below. + const auto a_to_interleave = ResizeBitCast(d8_interleave, a); + + const auto a_interleaved_lo = + InterleaveLower(d8_interleave, a_to_interleave, a_to_interleave); + const auto a_interleaved_hi = + InterleaveUpper(d8_interleave, a_to_interleave, a_to_interleave); + + /* a01: { a[kAOffset*4+0], a[kAOffset*4+1], a[kAOffset*4+1], a[kAOffset*4+2], + a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8] } + */ + /* a23: { a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8], + a[kAOffset*4+8], a[kAOffset*4+9], a[kAOffset*4+9], a[kAOffset*4+10] + } */ + + // a01 and a23 are resized back to V8 as only the first Lanes(d8) lanes of + // the CombineShiftRightBytes are needed for the subsequent AbsDiff operations + // and as a01 and a23 need to be the same vector type as b01 and b23 for the + // AbsDiff operations below. 
+ const V8 a01 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + const V8 a23 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + + /* b01: { b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1] } + */ + /* b23: { b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3] } + */ + const V8 b01 = BitCast(d8, Broadcast(BitCast(d16, b))); + const V8 b23 = BitCast(d8, Broadcast(BitCast(d16, b))); + + const VFromD absdiff_sum_01 = + SumsOf2(BitCast(du8, AbsDiff(a01, b01))); + const VFromD absdiff_sum_23 = + SumsOf2(BitCast(du8, AbsDiff(a23, b23))); + return BitCast(d16, Add(absdiff_sum_01, absdiff_sum_23)); +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfShuffledQuadAbsDiff(V8 a, + V8 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + +#if HWY_TARGET == HWY_RVV + // On RVV, ensure that both vA and vB have a LMUL of at least 1/2 so that + // both vA and vB can be bitcasted to a u32 vector. + const detail::AdjustSimdTagToMinVecPow2< + RepartitionToWideX2>> + d32; + const RepartitionToNarrow d16; + const RepartitionToNarrow d8; + + const auto vA = ResizeBitCast(d8, a); + const auto vB = ResizeBitCast(d8, b); +#else + const DFromV d8; + const RepartitionToWide d16; + const RepartitionToWide d32; + + const auto vA = a; + const auto vB = b; +#endif + + const RebindToUnsigned du8; + + const auto a_shuf = + Per4LaneBlockShuffle(BitCast(d32, vA)); + /* a0123_2345: { a_shuf[0], a_shuf[1], a_shuf[2], a_shuf[3], + a_shuf[2], a_shuf[3], a_shuf[4], a_shuf[5], + a_shuf[8], a_shuf[9], a_shuf[10], a_shuf[11], + a_shuf[10], a_shuf[11], a_shuf[12], a_shuf[13] } */ + /* a1234_3456: { a_shuf[1], a_shuf[2], a_shuf[3], a_shuf[4], + a_shuf[3], a_shuf[4], a_shuf[5], a_shuf[6], + a_shuf[9], a_shuf[10], a_shuf[11], a_shuf[12], + a_shuf[11], a_shuf[12], a_shuf[13], a_shuf[14] } */ +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On RVV/SVE targets, use Slide1Up/Slide1Down instead of + // ShiftLeftBytes/ShiftRightBytes to avoid unnecessary zeroing out of any + // lanes that are shifted into an adjacent 16-byte block as any lanes that are + // shifted into an adjacent 16-byte block by Slide1Up/Slide1Down will be + // replaced by the OddEven operation. 
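Both SumsOfAdjQuadAbsDiff above and SumsOfShuffledQuadAbsDiff below are built from the same primitive: absolute differences of u8 lanes summed over groups of four (AbsDiff followed by SumsOf2/SumsOf4). The shuffles, interleaves and broadcasts merely arrange which 4-byte window of a is compared against which 4 bytes of b. A scalar model of that building block, with illustrative indices (see the Highway documentation for the exact lane mapping of each op):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Sum of absolute differences of one 4-byte window of `a` against a fixed
// 4-byte reference taken from `b`.
static uint16_t Sad4(const uint8_t* a, const uint8_t* b_ref) {
  uint16_t sum = 0;
  for (int j = 0; j < 4; ++j) {
    sum = static_cast<uint16_t>(sum + std::abs(int(a[j]) - int(b_ref[j])));
  }
  return sum;
}

int main() {
  const uint8_t a[8] = {10, 20, 30, 40, 50, 60, 70, 80};
  const uint8_t b[4] = {12, 18, 33, 40};
  // Adjacent windows a[0..3], a[1..4], ... against the same reference bytes.
  assert(Sad4(a + 0, b) == 2 + 2 + 3 + 0);
  assert(Sad4(a + 1, b) == 8 + 12 + 7 + 10);
  return 0;
}
```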
+ const auto a_0123_2345 = BitCast( + d8, OddEven(BitCast(d32, Slide1Up(d16, BitCast(d16, a_shuf))), a_shuf)); + const auto a_1234_3456 = + BitCast(d8, OddEven(BitCast(d32, Slide1Up(d8, BitCast(d8, a_shuf))), + BitCast(d32, Slide1Down(d8, BitCast(d8, a_shuf))))); +#else + const auto a_0123_2345 = + BitCast(d8, OddEven(ShiftLeftBytes<2>(d32, a_shuf), a_shuf)); + const auto a_1234_3456 = BitCast( + d8, + OddEven(ShiftLeftBytes<1>(d32, a_shuf), ShiftRightBytes<1>(d32, a_shuf))); +#endif + + auto even_sums = SumsOf4(BitCast(du8, AbsDiff(a_0123_2345, vB))); + auto odd_sums = SumsOf4(BitCast(du8, AbsDiff(a_1234_3456, vB))); + +#if HWY_IS_LITTLE_ENDIAN + odd_sums = ShiftLeft<16>(odd_sums); +#else + even_sums = ShiftLeft<16>(even_sums); +#endif + + const auto sums = OddEven(BitCast(d16, odd_sums), BitCast(d16, even_sums)); + +#if HWY_TARGET == HWY_RVV + return ResizeBitCast(RepartitionToWide>(), sums); +#else + return sums; +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF + +// ------------------------------ BitShuffle (Rol) +#if (defined(HWY_NATIVE_BITSHUFFLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +#if HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR +template ), HWY_IF_UI8(TFromV)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 + const Repartition d_idx_shr; +#else + const Repartition d_idx_shr; +#endif + +#if HWY_IS_LITTLE_ENDIAN + constexpr uint64_t kExtractedBitsMask = + static_cast(0x8040201008040201u); +#else + constexpr uint64_t kExtractedBitsMask = + static_cast(0x0102040810204080u); +#endif + + const auto k7 = Set(du8, uint8_t{0x07}); + + auto unmasked_byte_idx = BitCast(du8, ShiftRight<3>(BitCast(d_idx_shr, idx))); +#if HWY_IS_BIG_ENDIAN + // Need to invert the lower 3 bits of unmasked_byte_idx[i] on big-endian + // targets + unmasked_byte_idx = Xor(unmasked_byte_idx, k7); +#endif // HWY_IS_BIG_ENDIAN + + const auto byte_idx = BitwiseIfThenElse( + k7, unmasked_byte_idx, + BitCast(du8, Dup128VecFromValues(du64, uint64_t{0}, + uint64_t{0x0808080808080808u}))); + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. To correctly + // handle shift counts from -7 to 7, we rotate. + const auto rotate_left_bits = Sub(Iota(du8, uint8_t{0}), BitCast(du8, idx)); + + const auto extracted_bits = + And(Rol(TableLookupBytes(v, byte_idx), rotate_left_bits), + BitCast(du8, Set(du64, kExtractedBitsMask))); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. 
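As I read the BitShuffle implementation above, each u64 lane gathers eight selected bits into one byte: the table lookup fetches the byte containing bit idx[i], the rotate aligns that bit with output position i, and the 0x8040201008040201 mask plus SumsOf8 collapse the bit-sliced bytes into a single value. A scalar model of that per-lane effect (a sketch under that reading; the helper name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// For each of the 8 index bytes, pick bit idx[i] (0..63) of the 64-bit input
// and place it at bit i of the result.
static uint64_t BitGather8(uint64_t v, const uint8_t idx[8]) {
  uint64_t out = 0;
  for (int i = 0; i < 8; ++i) {
    const uint64_t bit = (v >> (idx[i] & 63)) & 1u;
    out |= bit << i;
  }
  return out;
}

int main() {
  const uint8_t reverse_low_byte[8] = {7, 6, 5, 4, 3, 2, 1, 0};
  assert(BitGather8(0b10110001u, reverse_low_byte) == 0b10001101u);

  const uint8_t pick_high_bits[8] = {63, 62, 61, 60, 59, 58, 57, 56};
  assert(BitGather8(0xF0F0F0F0F0F0F0F0ull, pick_high_bits) == 0x0F);
  return 0;
}
```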
+ return BitCast(d64, SumsOf8(extracted_bits)); +} +#endif // HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_BITSHUFFLE + // ================================================== Operator wrapper // SVE* and RVV currently cannot define operators and have already defined @@ -4700,6 +7327,10 @@ template HWY_API V Div(V a, V b) { return a / b; } +template +HWY_API V Mod(V a, V b) { + return a % b; +} template V Shl(V a, V b) { @@ -4739,6 +7370,8 @@ HWY_API auto Le(V a, V b) -> decltype(a == b) { #endif // HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_GENERIC_IF_EMULATED_D + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/r/src/vendor/highway/hwy/ops/inside-inl.h b/r/src/vendor/highway/hwy/ops/inside-inl.h new file mode 100644 index 00000000..07759afe --- /dev/null +++ b/r/src/vendor/highway/hwy/ops/inside-inl.h @@ -0,0 +1,691 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Must be included inside an existing include guard, with the following ops +// already defined: BitCast, And, Set, ShiftLeft, ShiftRight, PromoteLowerTo, +// ConcatEven, ConcatOdd, plus the optional detail::PromoteEvenTo and +// detail::PromoteOddTo (if implemented in the target-specific header). + +// This is normally set by set_macros-inl.h before this header is included; +// if not, we are viewing this header standalone. Reduce IDE errors by: +#if !defined(HWY_NAMESPACE) +// 1) Defining HWY_IDE so we get syntax highlighting rather than all-gray text. +#include "hwy/ops/shared-inl.h" +// 2) Entering the HWY_NAMESPACE to make definitions from shared-inl.h visible. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +#define HWY_INSIDE_END_NAMESPACE +// 3) Providing a dummy VFromD (usually done by the target-specific header). +template +using VFromD = int; +template +using TFromV = int; +template +struct DFromV {}; +#endif + +// ------------------------------ Vec/Create/Get/Set2..4 + +// On SVE and RVV, Vec2..4 are aliases to built-in types. Also exclude the +// fixed-size SVE targets. +#if HWY_IDE || (!HWY_HAVE_SCALABLE && !HWY_TARGET_IS_SVE) + +// NOTE: these are used inside arm_neon-inl.h, hence they cannot be defined in +// generic_ops-inl.h, which is included after that. +template +struct Vec2 { + VFromD v0; + VFromD v1; +}; + +template +struct Vec3 { + VFromD v0; + VFromD v1; + VFromD v2; +}; + +template +struct Vec4 { + VFromD v0; + VFromD v1; + VFromD v2; + VFromD v3; +}; + +// D arg is unused but allows deducing D. 
+template +HWY_API Vec2 Create2(D /* tag */, VFromD v0, VFromD v1) { + return Vec2{v0, v1}; +} + +template +HWY_API Vec3 Create3(D /* tag */, VFromD v0, VFromD v1, VFromD v2) { + return Vec3{v0, v1, v2}; +} + +template +HWY_API Vec4 Create4(D /* tag */, VFromD v0, VFromD v1, VFromD v2, + VFromD v3) { + return Vec4{v0, v1, v2, v3}; +} + +template +HWY_API VFromD Get2(Vec2 tuple) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : tuple.v1; +} + +template +HWY_API VFromD Get3(Vec3 tuple) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : kIndex == 1 ? tuple.v1 : tuple.v2; +} + +template +HWY_API VFromD Get4(Vec4 tuple) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 + : kIndex == 1 ? tuple.v1 + : kIndex == 2 ? tuple.v2 + : tuple.v3; +} + +template +HWY_API Vec2 Set2(Vec2 tuple, VFromD val) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else { + tuple.v1 = val; + } + return tuple; +} + +template +HWY_API Vec3 Set3(Vec3 tuple, VFromD val) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else { + tuple.v2 = val; + } + return tuple; +} + +template +HWY_API Vec4 Set4(Vec4 tuple, VFromD val) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else if (kIndex == 2) { + tuple.v2 = val; + } else { + tuple.v3 = val; + } + return tuple; +} + +#endif // !HWY_HAVE_SCALABLE || HWY_IDE + +// ------------------------------ Rol/Ror (And, Or, Neg, Shl, Shr) +#if (defined(HWY_NATIVE_ROL_ROR_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_8 + +#if (defined(HWY_NATIVE_ROL_ROR_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shr_amt = And(BitCast(du, 
b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_16 + +#if (defined(HWY_NATIVE_ROL_ROR_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_32_64 + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 7; + const int shr_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 7; + const int shl_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_SAME_8 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 15; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + 
return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 15; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_NATIVE_ROL_ROR_SAME_16 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 31; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 31; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 63; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 63; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_SAME_32_64 + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// These are used by target-specific headers for ReorderWidenMulAccumulate etc. + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +// Tag dispatch is used in detail::PromoteEvenTo and detail::PromoteOddTo as +// there are target-specific specializations for some of the +// detail::PromoteEvenTo and detail::PromoteOddTo cases on +// SVE/PPC/SSE2/SSSE3/SSE4/AVX2. + +// All targets except HWY_SCALAR use the implementations of +// detail::PromoteEvenTo and detail::PromoteOddTo in generic_ops-inl.h for at +// least some of the PromoteEvenTo and PromoteOddTo cases. + +// Signed to signed PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted + // vector left by kToLaneSize * 4 bits to get the bits of the even + // source lanes into the upper kToLaneSize * 4 bits of even_in_hi. + const auto even_in_hi = ShiftLeft(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the upper kToLaneSize * 4 bits of the lanes of the bitcasted + // vector. 
+ const auto even_in_hi = BitCast(d_to, v); +#endif + + // Right-shift even_in_hi by kToLaneSize * 4 bits + return ShiftRight(even_in_hi); +#endif // HWY_TARGET_IS_SVE +} + +// Unsigned to unsigned PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. + + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. + return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the even source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#endif +#endif // HWY_TARGET_IS_SVE +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the odd source lanes are already in + // the upper kToLaneSize * 4 bits of the lanes of the bitcasted vector. + const auto odd_in_hi = BitCast(d_to, v); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // left by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the upper kToLaneSize * 4 bits of odd_in_hi. + const auto odd_in_hi = ShiftLeft(BitCast(d_to, v)); +#endif + + // Right-shift odd_in_hi by kToLaneSize * 4 bits + return ShiftRight(odd_in_hi); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. + + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. 
+ return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#endif +} + +// Unsigned to signed: Same as unsigned->unsigned PromoteEvenTo/PromoteOddTo +// followed by BitCast to signed +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteEvenTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteOddTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +// BF16->F32 PromoteEvenTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. + +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. + return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#else + // On big-endian platforms, the even lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. + return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#endif +} + +// BF16->F32 PromoteOddTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. + +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, the odd lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. + return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#else + // On big-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. 
+ return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#endif +} + +// Default PromoteEvenTo/PromoteOddTo implementations +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatEven(d, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatOdd(d, v, v)); +} + +} // namespace detail + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteEvenTo(D d, V v) { + return detail::PromoteEvenTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteOddTo(D d, V v) { + return detail::PromoteOddTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} +#endif // HWY_TARGET != HWY_SCALAR + +#ifdef HWY_INSIDE_END_NAMESPACE +#undef HWY_INSIDE_END_NAMESPACE +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); +#endif diff --git a/r/src/vendor/highway/hwy/ops/ppc_vsx-inl.h b/r/src/vendor/highway/hwy/ops/ppc_vsx-inl.h index d2589285..d216c548 100644 --- a/r/src/vendor/highway/hwy/ops/ppc_vsx-inl.h +++ b/r/src/vendor/highway/hwy/ops/ppc_vsx-inl.h @@ -13,9 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -// 128-bit vectors for VSX +// 128-bit vectors for VSX/Z14 // External include guard in highway.h - see comment there. +#if HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 +#define HWY_S390X_HAVE_Z14 1 +#else +#define HWY_S390X_HAVE_Z14 0 +#endif + #pragma push_macro("vector") #pragma push_macro("pixel") #pragma push_macro("bool") @@ -24,7 +30,11 @@ #undef pixel #undef bool +#if HWY_S390X_HAVE_Z14 +#include +#else #include +#endif #pragma pop_macro("vector") #pragma pop_macro("pixel") @@ -37,20 +47,26 @@ // This means we can only use POWER10-specific intrinsics in static dispatch // mode (where the -mpower10-vector compiler flag is passed). Same for PPC9. // On other compilers, the usual target check is sufficient. 
-#if HWY_TARGET <= HWY_PPC9 && \ +#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC9 && \ (defined(_ARCH_PWR9) || defined(__POWER9_VECTOR__)) #define HWY_PPC_HAVE_9 1 #else #define HWY_PPC_HAVE_9 0 #endif -#if HWY_TARGET <= HWY_PPC10 && \ +#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC10 && \ (defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__)) #define HWY_PPC_HAVE_10 1 #else #define HWY_PPC_HAVE_10 0 #endif +#if HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_Z15 && __ARCH__ >= 13 +#define HWY_S390X_HAVE_Z15 1 +#else +#define HWY_S390X_HAVE_Z15 0 +#endif + HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { @@ -125,6 +141,9 @@ class Vec128 { HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } @@ -180,9 +199,6 @@ HWY_API Vec128 Zero(D /* tag */) { template using VFromD = decltype(Zero(D())); -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ BitCast template @@ -215,6 +231,12 @@ HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{vec_splats(static_cast(t))}; } +template )> +HWY_API VFromD Set(D d, TFromD t) { + const RebindToUnsigned du; + return BitCast(d, Set(du, BitCastScalar>(t))); +} + // Returns a vector with uninitialized elements. template HWY_API VFromD Undefined(D d) { @@ -222,6 +244,8 @@ HWY_API VFromD Undefined(D d) { // Suppressing maybe-uninitialized both here and at the caller does not work, // so initialize. return Zero(d); +#elif HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(d).raw)}; #else HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") @@ -240,6 +264,58 @@ HWY_API T GetLane(Vec128 v) { return static_cast(v.raw[0]); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const typename detail::Raw128>::type raw = { + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3, + t4, t5, t6, t7}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + const typename detail::Raw128>::type raw = {t0, t1}; + return VFromD{raw}; +} + // ================================================== LOGICAL // ------------------------------ And @@ -249,7 +325,11 @@ HWY_API Vec128 
And(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw & BitCast(du, b).raw}); +#else return BitCast(d, VU{vec_and(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif } // ------------------------------ AndNot @@ -271,7 +351,11 @@ HWY_API Vec128 Or(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw | BitCast(du, b).raw}); +#else return BitCast(d, VU{vec_or(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif } // ------------------------------ Xor @@ -281,7 +365,11 @@ HWY_API Vec128 Xor(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw ^ BitCast(du, b).raw}); +#else return BitCast(d, VU{vec_xor(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif } // ------------------------------ Not @@ -476,9 +564,21 @@ HWY_API Vec128 operator^(Vec128 a, Vec128 b) { // ------------------------------ Neg -template -HWY_INLINE Vec128 Neg(Vec128 v) { +template +HWY_API Vec128 Neg(Vec128 v) { + // If T is an signed integer type, use Zero(d) - v instead of vec_neg to + // avoid undefined behavior in the case where v[i] == LimitsMin() + const DFromV d; + return Zero(d) - v; +} + +template +HWY_API Vec128 Neg(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Xor(v, SignBit(DFromV())); +#else return Vec128{vec_neg(v.raw)}; +#endif } template @@ -489,13 +589,40 @@ HWY_API Vec128 Neg(const Vec128 v) { // ------------------------------ Abs // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. -template +template +HWY_API Vec128 Abs(Vec128 v) { + // If T is a signed integer type, use Max(v, Neg(v)) instead of vec_abs to + // avoid undefined behavior in the case where v[i] == LimitsMin(). + return Max(v, Neg(v)); +} + +template HWY_API Vec128 Abs(Vec128 v) { return Vec128{vec_abs(v.raw)}; } // ------------------------------ CopySign +#if HWY_S390X_HAVE_Z14 +template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + + // Truth table for msb, magn, sign | bitwise msb ? 
sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + return BitwiseIfThenElse(msb, sign, magn); +} +#else // VSX template HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { @@ -525,6 +652,7 @@ HWY_API Vec128 CopySign(Vec128 magn, return Vec128{vec_cpsgn(sign.raw, magn.raw)}; #endif } +#endif // HWY_S390X_HAVE_Z14 template HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { @@ -542,10 +670,21 @@ HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { template > HWY_API Vec128 Load(D /* tag */, const T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(const LoadRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + using LoadRaw = typename detail::Raw128::AlignedRawVec; - const LoadRaw* HWY_RESTRICT p = reinterpret_cast(aligned); + const LoadRaw* HWY_RESTRICT p = HWY_RCAST_ALIGNED(const LoadRaw*, aligned); using ResultRaw = typename detail::Raw128::type; return Vec128{reinterpret_cast(*p)}; + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif } // Any <= 64 bit @@ -598,19 +737,13 @@ HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, // mask ? yes : 0 template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { - const DFromV d; - const RebindToUnsigned du; - return BitCast(d, - VFromD{vec_and(BitCast(du, yes).raw, mask.raw)}); + return yes & VecFromMask(DFromV(), mask); } // mask ? 0 : no template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { - const DFromV d; - const RebindToUnsigned du; - return BitCast(d, - VFromD{vec_andc(BitCast(du, no).raw, mask.raw)}); + return AndNot(VecFromMask(DFromV(), mask), no); } // ------------------------------ Mask logical @@ -622,7 +755,11 @@ HWY_API Mask128 Not(Mask128 m) { template HWY_API Mask128 And(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw & b.raw}; +#else return Mask128{vec_and(a.raw, b.raw)}; +#endif } template @@ -632,12 +769,20 @@ HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { template HWY_API Mask128 Or(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw | b.raw}; +#else return Mask128{vec_or(a.raw, b.raw)}; +#endif } template HWY_API Mask128 Xor(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw ^ b.raw}; +#else return Mask128{vec_xor(a.raw, b.raw)}; +#endif } template @@ -645,36 +790,24 @@ HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { return Mask128{vec_nor(a.raw, b.raw)}; } -// ------------------------------ BroadcastSignBit - -template -HWY_API Vec128 BroadcastSignBit(Vec128 v) { - return Vec128{ - vec_sra(v.raw, vec_splats(static_cast(7)))}; -} - -template -HWY_API Vec128 BroadcastSignBit(Vec128 v) { - return Vec128{ - vec_sra(v.raw, vec_splats(static_cast(15)))}; -} - -template -HWY_API Vec128 BroadcastSignBit(Vec128 v) { - return Vec128{vec_sra(v.raw, vec_splats(31u))}; -} - -template -HWY_API Vec128 BroadcastSignBit(Vec128 v) { - return Vec128{vec_sra(v.raw, vec_splats(63ULL))}; -} - // ------------------------------ ShiftLeftSame template HWY_API Vec128 ShiftLeftSame(Vec128 v, const int bits) { - using TU = typename detail::Raw128>::RawT; - return Vec128{vec_sl(v.raw, vec_splats(static_cast(bits)))}; + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, + VFromD{BitCast(du, v).raw + << Set(du, static_cast(bits)).raw}); +#else + // Do an 
unsigned vec_sl operation to avoid undefined behavior + return BitCast( + d, VFromD{ + vec_sl(BitCast(du, v).raw, Set(du, static_cast(bits)).raw)}); +#endif } // ------------------------------ ShiftRightSame @@ -682,13 +815,22 @@ HWY_API Vec128 ShiftLeftSame(Vec128 v, const int bits) { template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { using TU = typename detail::Raw128>::RawT; +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else return Vec128{vec_sr(v.raw, vec_splats(static_cast(bits)))}; +#endif } template HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { +#if HWY_S390X_HAVE_Z14 + using TI = typename detail::Raw128::RawT; + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else using TU = typename detail::Raw128>::RawT; return Vec128{vec_sra(v.raw, vec_splats(static_cast(bits)))}; +#endif } // ------------------------------ ShiftLeft @@ -707,6 +849,13 @@ HWY_API Vec128 ShiftRight(Vec128 v) { return ShiftRightSame(v, kBits); } +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + return ShiftRightSame(v, static_cast(sizeof(T) * 8 - 1)); +} + // ================================================== SWIZZLE (1) // ------------------------------ TableLookupBytes @@ -1003,7 +1152,7 @@ HWY_API VFromD LoadDup128(D d, const T* HWY_RESTRICT p) { return LoadU(d, p); } -#if HWY_PPC_HAVE_9 +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 #ifdef HWY_NATIVE_LOAD_N #undef HWY_NATIVE_LOAD_N #else @@ -1027,11 +1176,20 @@ HWY_API VFromD LoadN(D d, const T* HWY_RESTRICT p, const size_t num_of_bytes_to_load = HWY_MIN(max_lanes_to_load, HWY_MAX_LANES_D(D)) * sizeof(TFromD); const Repartition du8; +#if HWY_S390X_HAVE_Z14 + return (num_of_bytes_to_load > 0) + ? BitCast(d, VFromD{vec_load_len( + const_cast( + reinterpret_cast(p)), + static_cast(num_of_bytes_to_load - 1))}) + : Zero(d); +#else return BitCast( d, VFromD{vec_xl_len( const_cast(reinterpret_cast(p)), num_of_bytes_to_load)}); +#endif } template > @@ -1048,18 +1206,11 @@ HWY_API VFromD LoadNOr(VFromD no, D d, const T* HWY_RESTRICT p, } #endif - const size_t num_of_bytes_to_load = - HWY_MIN(max_lanes_to_load, HWY_MAX_LANES_D(D)) * sizeof(TFromD); - const Repartition du8; - const VFromD v = BitCast( - d, - VFromD{vec_xl_len( - const_cast(reinterpret_cast(p)), - num_of_bytes_to_load)}); - return IfThenElse(FirstN(d, max_lanes_to_load), v, no); + return IfThenElse(FirstN(d, max_lanes_to_load), + LoadN(d, p, max_lanes_to_load), no); } -#endif // HWY_PPC_HAVE_9 +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 // Returns a vector with lane i=[0, N) set to "first" + i. 
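As a plain-C++ reference for the `LoadN`/`LoadNOr` overloads added in this hunk (a sketch of the intended semantics only: arrays stand in for vectors, `kLanes` stands in for `Lanes(d)`, and the `*_Ref` helper names are illustrative): at most `max_lanes` elements are read from memory, and the remaining lanes are filled with zero (`LoadN`) or with `no` (`LoadNOr`), which is what the `vec_xl_len`/`vec_load_len` byte counts above arrange.

```cpp
#include <algorithm>
#include <cstddef>

// Scalar model of LoadN: never reads more than max_lanes (or kLanes) elements;
// the tail lanes are zero-filled.
template <typename T, size_t kLanes>
void LoadN_Ref(const T* p, size_t max_lanes, T (&out)[kLanes]) {
  const size_t n = std::min(max_lanes, kLanes);
  for (size_t i = 0; i < n; ++i) out[i] = p[i];
  for (size_t i = n; i < kLanes; ++i) out[i] = T{0};
}

// Scalar model of LoadNOr: the tail lanes are filled with `no` instead.
template <typename T, size_t kLanes>
void LoadNOr_Ref(T no, const T* p, size_t max_lanes, T (&out)[kLanes]) {
  const size_t n = std::min(max_lanes, kLanes);
  for (size_t i = 0; i < n; ++i) out[i] = p[i];
  for (size_t i = n; i < kLanes; ++i) out[i] = no;
}
```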
namespace detail { @@ -1134,8 +1285,19 @@ HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, template > HWY_API void Store(Vec128 v, D /* tag */, T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(StoreRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + using StoreRaw = typename detail::Raw128::AlignedRawVec; - *reinterpret_cast(aligned) = reinterpret_cast(v.raw); + *HWY_RCAST_ALIGNED(StoreRaw*, aligned) = reinterpret_cast(v.raw); + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif } template > @@ -1159,7 +1321,7 @@ HWY_API void StoreU(VFromD v, D d, T* HWY_RESTRICT p) { Store(v, d, p); } -#if HWY_PPC_HAVE_9 +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 #ifdef HWY_NATIVE_STORE_N #undef HWY_NATIVE_STORE_N @@ -1185,8 +1347,15 @@ HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, const size_t num_of_bytes_to_store = HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)) * sizeof(TFromD); const Repartition du8; +#if HWY_S390X_HAVE_Z14 + if (num_of_bytes_to_store > 0) { + vec_store_len(BitCast(du8, v).raw, reinterpret_cast(p), + static_cast(num_of_bytes_to_store - 1)); + } +#else vec_xst_len(BitCast(du8, v).raw, reinterpret_cast(p), num_of_bytes_to_store); +#endif } #endif @@ -1195,180 +1364,104 @@ HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, template HWY_API void BlendedStore(VFromD v, MFromD m, D d, TFromD* HWY_RESTRICT p) { - const RebindToSigned di; // for testing mask if T=bfloat16_t. - using TI = TFromD; - alignas(16) TI buf[MaxLanes(d)]; - alignas(16) TI mask[MaxLanes(d)]; - Store(BitCast(di, v), di, buf); - Store(BitCast(di, VecFromMask(d, m)), di, mask); - for (size_t i = 0; i < MaxLanes(d); ++i) { - if (mask[i]) { - CopySameSize(buf + i, p + i); - } - } + const VFromD old = LoadU(d, p); + StoreU(IfThenElse(RebindMask(d, m), v, old), d, p); } // ================================================== ARITHMETIC +namespace detail { +// If TFromD is an integer type, detail::RebindToUnsignedIfNotFloat +// rebinds D to MakeUnsigned>. + +// Otherwise, if TFromD is a floating-point type (including F16 and BF16), +// detail::RebindToUnsignedIfNotFloat is the same as D. +template +using RebindToUnsignedIfNotFloat = + hwy::If<(!hwy::IsFloat>() && !hwy::IsSpecialFloat>()), + RebindToUnsigned, D>; +} // namespace detail + // ------------------------------ Addition template HWY_API Vec128 operator+(Vec128 a, Vec128 b) { - return Vec128{vec_add(a.raw, b.raw)}; + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_add to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw + + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_add( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif } // ------------------------------ Subtraction template HWY_API Vec128 operator-(Vec128 a, Vec128 b) { - return Vec128{vec_sub(a.raw, b.raw)}; -} - -// ------------------------------ SumsOf8 -namespace detail { - -// Casts nominally int32_t result to D. 
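The integer `operator+` above (and the rewritten `operator-` below) route the arithmetic through an unsigned type via `detail::RebindToUnsignedIfNotFloat`, because signed overflow is undefined behaviour in C++ whereas unsigned arithmetic wraps and yields the intended two's-complement bit pattern. A minimal scalar sketch of the same idea (the helper name is illustrative):

```cpp
#include <cstdint>

// Wrapping signed add without undefined behaviour: add the unsigned bit
// patterns, then cast back. (Since C++20 the conversion back to int32_t is
// defined as modulo 2^32; before that it is implementation-defined but yields
// the two's-complement result on the compilers relevant here.)
static inline int32_t WrappingAddI32(int32_t a, int32_t b) {
  return static_cast<int32_t>(static_cast<uint32_t>(a) +
                              static_cast<uint32_t>(b));
}
```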
-template -HWY_INLINE VFromD AltivecVsum4sbs(D d, __vector signed char a, - __vector signed int b) { - const Repartition di32; -#ifdef __OPTIMIZE__ - if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { - const int64_t sum0 = - static_cast(a[0]) + static_cast(a[1]) + - static_cast(a[2]) + static_cast(a[3]) + - static_cast(b[0]); - const int64_t sum1 = - static_cast(a[4]) + static_cast(a[5]) + - static_cast(a[6]) + static_cast(a[7]) + - static_cast(b[1]); - const int64_t sum2 = - static_cast(a[8]) + static_cast(a[9]) + - static_cast(a[10]) + static_cast(a[11]) + - static_cast(b[2]); - const int64_t sum3 = - static_cast(a[12]) + static_cast(a[13]) + - static_cast(a[14]) + static_cast(a[15]) + - static_cast(b[3]); - const int32_t sign0 = static_cast(sum0 >> 63); - const int32_t sign1 = static_cast(sum1 >> 63); - const int32_t sign2 = static_cast(sum2 >> 63); - const int32_t sign3 = static_cast(sum3 >> 63); - using Raw = typename detail::Raw128::type; - return BitCast( - d, - VFromD{Raw{ - (sign0 == (sum0 >> 31)) ? static_cast(sum0) - : static_cast(sign0 ^ 0x7FFFFFFF), - (sign1 == (sum1 >> 31)) ? static_cast(sum1) - : static_cast(sign1 ^ 0x7FFFFFFF), - (sign2 == (sum2 >> 31)) ? static_cast(sum2) - : static_cast(sign2 ^ 0x7FFFFFFF), - (sign3 == (sum3 >> 31)) - ? static_cast(sum3) - : static_cast(sign3 ^ 0x7FFFFFFF)}}); - } else // NOLINT -#endif - { - return BitCast(d, VFromD{vec_vsum4sbs(a, b)}); - } -} + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; -// Casts nominally uint32_t result to D. -template -HWY_INLINE VFromD AltivecVsum4ubs(D d, __vector unsigned char a, - __vector unsigned int b) { - const Repartition du32; -#ifdef __OPTIMIZE__ - if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { - const uint64_t sum0 = - static_cast(a[0]) + static_cast(a[1]) + - static_cast(a[2]) + static_cast(a[3]) + - static_cast(b[0]); - const uint64_t sum1 = - static_cast(a[4]) + static_cast(a[5]) + - static_cast(a[6]) + static_cast(a[7]) + - static_cast(b[1]); - const uint64_t sum2 = - static_cast(a[8]) + static_cast(a[9]) + - static_cast(a[10]) + static_cast(a[11]) + - static_cast(b[2]); - const uint64_t sum3 = - static_cast(a[12]) + static_cast(a[13]) + - static_cast(a[14]) + static_cast(a[15]) + - static_cast(b[3]); - return BitCast( - d, - VFromD{(__vector unsigned int){ - static_cast(sum0 <= 0xFFFFFFFFu ? sum0 : 0xFFFFFFFFu), - static_cast(sum1 <= 0xFFFFFFFFu ? sum1 : 0xFFFFFFFFu), - static_cast(sum2 <= 0xFFFFFFFFu ? sum2 : 0xFFFFFFFFu), - static_cast(sum3 <= 0xFFFFFFFFu ? sum3 - : 0xFFFFFFFFu)}}); - } else // NOLINT + // If T is an integer type, do an unsigned vec_sub to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw - + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_sub( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); #endif - { - return BitCast(d, VFromD{vec_vsum4ubs(a, b)}); - } } -// Casts nominally int32_t result to D. 
-template -HWY_INLINE VFromD AltivecVsum2sws(D d, __vector signed int a, - __vector signed int b) { - const Repartition di32; -#ifdef __OPTIMIZE__ - const Repartition du64; - constexpr int kDestLaneOffset = HWY_IS_BIG_ENDIAN; - if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset]) && - __builtin_constant_p(b[kDestLaneOffset + 2])) { - const int64_t sum0 = static_cast(a[0]) + - static_cast(a[1]) + - static_cast(b[kDestLaneOffset]); - const int64_t sum1 = static_cast(a[2]) + - static_cast(a[3]) + - static_cast(b[kDestLaneOffset + 2]); - const int32_t sign0 = static_cast(sum0 >> 63); - const int32_t sign1 = static_cast(sum1 >> 63); - return BitCast(d, VFromD{(__vector unsigned long long){ - (sign0 == (sum0 >> 31)) - ? static_cast(sum0) - : static_cast(sign0 ^ 0x7FFFFFFF), - (sign1 == (sum1 >> 31)) - ? static_cast(sum1) - : static_cast(sign1 ^ 0x7FFFFFFF)}}); - } else // NOLINT -#endif - { - __vector signed int sum; - - // Inline assembly is used for vsum2sws to avoid unnecessary shuffling - // on little-endian PowerPC targets as the result of the vsum2sws - // instruction will already be in the correct lanes on little-endian - // PowerPC targets. - __asm__("vsum2sws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); - - return BitCast(d, VFromD{sum}); - } +// ------------------------------ SumsOf8 +template )> +HWY_API VFromD>> SumsOf8(V v) { + return SumsOf2(SumsOf4(v)); } -} // namespace detail - -template -HWY_API Vec128 SumsOf8(Vec128 v) { - const Repartition> du64; - const Repartition di32; - const RebindToUnsigned du32; +template )> +HWY_API VFromD>> SumsOf8(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV di8; + const RebindToUnsigned du8; + const RepartitionToWideX3 di64; - return detail::AltivecVsum2sws( - du64, detail::AltivecVsum4ubs(di32, v.raw, Zero(du32).raw).raw, - Zero(di32).raw); + return BitCast(di64, SumsOf8(BitCast(du8, Xor(v, SignBit(di8))))) + + Set(di64, int64_t{-1024}); +#else + return SumsOf2(SumsOf4(v)); +#endif } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. +#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedAdd instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + return Add(a, Min(b, Not(a))); +} + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +#else // VSX + #ifdef HWY_NATIVE_I32_SATURATED_ADDSUB #undef HWY_NATIVE_I32_SATURATED_ADDSUB #else @@ -1386,6 +1479,7 @@ template SaturatedAdd(Vec128 a, Vec128 b) { return Vec128{vec_adds(a.raw, b.raw)}; } +#endif // HWY_S390X_HAVE_Z14 #if HWY_PPC_HAVE_10 @@ -1412,11 +1506,34 @@ HWY_API V SaturatedAdd(V a, V b) { // Returns a - b clamped to the destination range. 
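The Z14 `SaturatedAdd` emulation above and the `SaturatedSub` emulation that follows build saturation from ordinary min/logic operations. A scalar sketch of the same checks, shown at 32-bit width for clarity (the Z14 paths apply them to 8/16-bit lanes; the helper names are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Unsigned saturating add: ~a is the remaining headroom (UINT32_MAX - a), so
// clamping the addend to it mirrors Add(a, Min(b, Not(a))).
static inline uint32_t SatAddU32(uint32_t a, uint32_t b) {
  return a + std::min(b, ~a);
}

// Unsigned saturating subtract: Sub(a, Min(a, b)) == max(a - b, 0).
static inline uint32_t SatSubU32(uint32_t a, uint32_t b) {
  return a - std::min(a, b);
}

// Signed saturating add: overflow occurred iff a and b share a sign but the
// wrapped sum does not; the saturated value is INT_MAX for non-negative a and
// INT_MIN for negative a, matching Xor(BroadcastSignBit(a), Set(d, LimitsMax)).
static inline int32_t SatAddI32(int32_t a, int32_t b) {
  const int32_t sum = static_cast<int32_t>(static_cast<uint32_t>(a) +
                                           static_cast<uint32_t>(b));
  const bool overflow = (~(a ^ b) & (a ^ sum)) < 0;
  const int32_t saturated =
      (a < 0 ? -1 : 0) ^ std::numeric_limits<int32_t>::max();
  return overflow ? saturated : sum;
}
```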
+#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedSub instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + return Sub(a, Min(a, b)); +} + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#else // VSX + template HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { return Vec128{vec_subs(a.raw, b.raw)}; } +#endif // HWY_S390X_HAVE_Z14 #if HWY_PPC_HAVE_10 @@ -1437,12 +1554,33 @@ HWY_API V SaturatedSub(V a, V b) { // Returns (a + b + 1) / 2 -template +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#if HWY_S390X_HAVE_Z14 +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) void* = nullptr +#else // !HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif // HWY_S390X_HAVE_Z14 + +template HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { return Vec128{vec_avg(a.raw, b.raw)}; } +#undef HWY_PPC_IF_AVERAGE_ROUND_T + // ------------------------------ Multiplication // Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. @@ -1459,33 +1597,97 @@ HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { template HWY_API Vec128 operator*(Vec128 a, Vec128 b) { - return Vec128{a.raw * b.raw}; + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_mul to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw * + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_mul( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. + +#if HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#elif HWY_PPC_HAVE_10 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2)) +#else +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif + +#if HWY_S390X_HAVE_Z14 || HWY_PPC_HAVE_10 +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + return Vec128{vec_mulh(a.raw, b.raw)}; +} +#endif + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const auto p_even = MulEven(a, b); + +#if HWY_IS_LITTLE_ENDIAN + const auto p_even_full = ResizeBitCast(Full128(), p_even); + return Vec128{ + vec_sld(p_even_full.raw, p_even_full.raw, 16 - sizeof(T))}; +#else + const DFromV d; + return ResizeBitCast(d, p_even); +#endif } -// Returns the upper 16 bits of a * b in each lane. 
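The `MulHigh` overloads in this hunk return the upper half of the double-width product: built from `MulEven`/`MulOdd` for narrow lanes, from `vec_mulh` where available, and from the scalar `Mul128` helper for 64-bit lanes without POWER10. As a scalar reference for the 32-bit case (a sketch, not the library implementation; the helper names are illustrative):

```cpp
#include <cstdint>

// Upper 32 bits of the widened product. The vector code computes the full
// 2N-bit products per lane pair and then keeps the high halves via the
// shift/interleave steps. The arithmetic right shift of the signed product is
// defined behaviour since C++20 and universal in practice.
static inline int32_t MulHighI32(int32_t a, int32_t b) {
  return static_cast<int32_t>(
      (static_cast<int64_t>(a) * static_cast<int64_t>(b)) >> 32);
}

static inline uint32_t MulHighU32(uint32_t a, uint32_t b) {
  return static_cast<uint32_t>(
      (static_cast<uint64_t>(a) * static_cast<uint64_t>(b)) >> 32);
}
```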
-template +template HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { const DFromV d; - const RepartitionToWide dw; - const VFromD p1{vec_mule(a.raw, b.raw)}; - const VFromD p2{vec_mulo(a.raw, b.raw)}; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + #if HWY_IS_LITTLE_ENDIAN - const __vector unsigned char kShuffle = {2, 3, 18, 19, 6, 7, 22, 23, - 10, 11, 26, 27, 14, 15, 30, 31}; + return InterleaveOdd(d, p_even, p_odd); #else - const __vector unsigned char kShuffle = {0, 1, 16, 17, 4, 5, 20, 21, - 8, 9, 24, 25, 12, 13, 28, 29}; + return InterleaveEven(d, p_even, p_odd); #endif - return BitCast(d, VFromD{vec_perm(p1.raw, p2.raw, kShuffle)}); } -template -HWY_API Vec128 MulFixedPoint15(Vec128 a, - Vec128 b) { - const Vec128 zero = Zero(Full128()); - return Vec128{vec_mradds(a.raw, b.raw, zero.raw)}; +#if !HWY_PPC_HAVE_10 +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T p_hi; + Mul128(GetLane(a), GetLane(b), &p_hi); + return Set(Full64(), p_hi); } +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const DFromV d; + const Half dh; + return Combine(d, MulHigh(UpperHalf(dh, a), UpperHalf(dh, b)), + MulHigh(LowerHalf(dh, a), LowerHalf(dh, b))); +} +#endif // !HWY_PPC_HAVE_10 + +#undef HWY_PPC_IF_MULHIGH_USING_VEC_MULH +#undef HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH + // Multiplies even lanes (0, 2, ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. template , (N + 1) / 2> MulOdd(Vec128 a, return Vec128, (N + 1) / 2>{vec_mulo(a.raw, b.raw)}; } +// ------------------------------ Rol/Ror + +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_rl(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToSigned di; + return Rol(a, BitCast(d, Neg(BitCast(di, b)))); +} + // ------------------------------ RotateRight -template +template HWY_API Vec128 RotateRight(const Vec128 v) { const DFromV d; constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); - if (kBits == 0) return v; - return Vec128{vec_rl(v.raw, Set(d, kSizeInBits - kBits).raw)}; + + return (kBits == 0) + ? 
v + : Rol(v, Set(d, static_cast(static_cast(kSizeInBits) - + kBits))); } -// ------------------------------ ZeroIfNegative (BroadcastSignBit) -template -HWY_API Vec128 ZeroIfNegative(Vec128 v) { - static_assert(IsFloat(), "Only works for float"); +// ------------------------------ RotateLeftSame/RotateRightSame +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template +HWY_API Vec128 RotateLeftSame(Vec128 v, int bits) { const DFromV d; - const RebindToSigned di; - const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); - return IfThenElse(mask, Zero(d), v); + return Rol(v, Set(d, static_cast(static_cast(bits)))); +} + +template +HWY_API Vec128 RotateRightSame(Vec128 v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast(0u - static_cast(bits)))); } // ------------------------------ IfNegativeThenElse @@ -1541,10 +1802,35 @@ HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, BitCast(du, no).raw, BitCast(du, yes).raw, BitCast(du, v).raw)}); #else const RebindToSigned di; - return IfThenElse(MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))), - yes, no); + return IfVecThenElse(BitCast(d, BroadcastSignBit(BitCast(di, v))), yes, no); +#endif +} + +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE #endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); } +#endif // generic_ops takes care of integer T. 
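The `RotateRight` and `RotateLeftSame`/`RotateRightSame` wrappers above forward to `Rol`, whose generic emulation earlier in this patch builds a rotate from two shifts with both counts reduced modulo the lane width, so neither shift ever equals the full width. A scalar model of that identity (assuming 32-bit lanes; the helper name is illustrative):

```cpp
#include <cstdint>

// Rotate left by a runtime count. Masking with (width - 1) keeps both shift
// amounts in [0, 31], so the expression is well-defined even for bits == 0,
// negative counts, or counts >= 32; (0u - bits) & 31 is the complementary
// right-shift amount, exactly as in the RotateLeftSame emulation.
static inline uint32_t RotateLeft32(uint32_t x, int bits) {
  const unsigned shl = static_cast<unsigned>(bits) & 31u;
  const unsigned shr = (0u - static_cast<unsigned>(bits)) & 31u;
  return (x << shl) | (x >> shr);
}
```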
template @@ -1598,17 +1884,42 @@ HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, #endif template -HWY_API Vec128 ApproximateReciprocal(Vec128 v) { - return Vec128{vec_re(v.raw)}; +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { +#if HWY_S390X_HAVE_Z14 + return Vec128{a.raw / b.raw}; +#else + return Vec128{vec_div(a.raw, b.raw)}; +#endif } template -HWY_API Vec128 operator/(Vec128 a, Vec128 b) { - return Vec128{vec_div(a.raw, b.raw)}; +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + return Set(d, T(1.0)) / v; +#else + return Vec128{vec_re(v.raw)}; +#endif } // ------------------------------ Floating-point square root +#if HWY_S390X_HAVE_Z14 +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + const auto half = v * Set(d, 0.5f); + // Initial guess based on log2(f) + const auto guess = BitCast( + d, Set(du, uint32_t{0x5F3759DFu}) - ShiftRight<1>(BitCast(du, v))); + // One Newton-Raphson iteration + return guess * NegMulAdd(half * guess, guess, Set(d, 1.5f)); +} +#else // VSX + #ifdef HWY_NATIVE_F64_APPROX_RSQRT #undef HWY_NATIVE_F64_APPROX_RSQRT #else @@ -1620,6 +1931,7 @@ template HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { return Vec128{vec_rsqrte(v.raw)}; } +#endif // HWY_S390X_HAVE_Z14 // Full precision square root template @@ -1668,6 +1980,167 @@ HWY_API V AbsDiff(const V a, const V b) { #endif // HWY_PPC_HAVE_9 +// ------------------------------ Integer Div for PPC10 +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vdivsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vdivuw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed long long raw_result; + __asm__("vdivsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_div on PPC10 if optimizations are + // enabled and 
any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned long long raw_result; + __asm__("vdivud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) / PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) / PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) / PromoteTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vmodsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vmoduw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed long long raw_result; + __asm__("vmodsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned long long raw_result; + __asm__("vmodud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) % PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) % PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) % PromoteTo(dw, b)); +} +#endif + // ================================================== MEMORY (3) // ------------------------------ Non-temporal stores @@ -1800,7 +2273,7 @@ template HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { #if HWY_IS_LITTLE_ENDIAN typename detail::Raw128::type raw_result = v.raw; - raw_result[i] = t; + raw_result[i] = BitCastScalar::RawT>(t); return 
Vec128{raw_result}; #else // On ppc64be without this, mul_test fails, but swizzle_test passes. @@ -2070,7 +2543,7 @@ HWY_API Vec32 Reverse(D d, Vec32 v) { // ------------------------------- ReverseLaneBytes -#if HWY_PPC_HAVE_9 && \ +#if (HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14) && \ (HWY_COMPILER_GCC_ACTUAL >= 710 || HWY_COMPILER_CLANG >= 400) // Per-target flag to prevent generic_ops-inl.h defining 8-bit ReverseLaneBytes. @@ -2111,7 +2584,7 @@ HWY_API VFromD Reverse8(D d, VFromD v) { return BitCast(d, ReverseLaneBytes(BitCast(du64, v))); } -#endif // HWY_PPC_HAVE_9 +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 template , HWY_IF_T_SIZE(T, 1)> HWY_API Vec16 Reverse(D d, Vec16 v) { @@ -2268,11 +2741,15 @@ HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { Set(Full128(), static_cast(amt * sizeof(TFromD) * 8))); +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_srb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX #if HWY_IS_LITTLE_ENDIAN return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); #else return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); -#endif +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 } // ------------------------------ SlideDownLanes @@ -2300,11 +2777,15 @@ HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { Set(Full128(), static_cast(amt * sizeof(TFromD) * 8))); +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_slb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX #if HWY_IS_LITTLE_ENDIAN return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); #else return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); -#endif +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 } // ================================================== COMBINE @@ -2637,7 +3118,15 @@ HWY_API Vec128 DupEven(Vec128 v) { template HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, + 11, 8, 9, 10, 11))); +#else return Vec128{vec_mergee(v.raw, v.raw)}; +#endif } // ------------------------------ DupOdd (InterleaveUpper) @@ -2662,7 +3151,15 @@ HWY_API Vec128 DupOdd(Vec128 v) { template HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, + 15, 12, 13, 14, 15))); +#else return Vec128{vec_mergeo(v.raw, v.raw)}; +#endif } template @@ -2706,6 +3203,96 @@ HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); } +// ------------------------------ InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, + 10, 26, 12, 28, 14, 30) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{Dup128VecFromValues(Full128(), 0, 1, + 16, 17, 4, 5, 20, 21, 8, + 9, 24, 25, 12, 13, 28, 29) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> 
idx{Dup128VecFromValues(Full128(), 0, 1, + 2, 3, 16, 17, 18, 19, 8, + 9, 10, 11, 24, 25, 26, 27) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergee(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, + 11, 27, 13, 29, 15, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 2, 3, 18, 19, 6, 7, 22, 23, 10, + 11, 26, 27, 14, 15, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 4, 5, 6, 7, 20, 21, 22, 23, 12, + 13, 14, 15, 28, 29, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergeo(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + // ------------------------------ OddEvenBlocks template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { @@ -2719,14 +3306,52 @@ HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } -// ------------------------------ Shl +// ------------------------------ MulFixedPoint15 (OddEven) -namespace detail { -template -HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, - Vec128 bits) { - return Vec128{vec_sl(v.raw, bits.raw)}; -} +#if HWY_S390X_HAVE_Z14 +HWY_API Vec16 MulFixedPoint15(Vec16 a, Vec16 b) { + const DFromV di16; + const RepartitionToWide di32; + + const auto round_up_incr = Set(di32, 0x4000); + const auto i32_product = MulEven(a, b) + round_up_incr; + + return ResizeBitCast(di16, ShiftLeft<1>(i32_product)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const DFromV di16; + const RepartitionToWide di32; + + const auto round_up_incr = Set(di32, 0x4000); + const auto even_product = MulEven(a, b) + round_up_incr; + const auto odd_product = MulOdd(a, b) + round_up_incr; + + return OddEven(BitCast(di16, ShiftRight<15>(odd_product)), + BitCast(di16, ShiftLeft<1>(even_product))); +} +#else +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const Vec128 zero = Zero(Full128()); + return Vec128{vec_mradds(a.raw, b.raw, zero.raw)}; +} +#endif + +// ------------------------------ Shl + +namespace detail { +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw << bits.raw}; +#else + return Vec128{vec_sl(v.raw, bits.raw)}; +#endif +} // Signed left shift is the same as unsigned. 
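As the comment above notes, the signed per-lane left shift can reuse the unsigned path: left-shifting a negative value is undefined in standard C++, but shifting the unsigned bit pattern and casting back yields the intended two's-complement result. A scalar sketch (the helper name is illustrative; the count is masked here only so the standalone sketch stays well-defined, since the vector op expects counts below the lane width):

```cpp
#include <cstdint>

// Left shift of a signed value, performed in the unsigned domain to avoid
// undefined behaviour for negative inputs.
static inline int32_t ShiftLeftI32(int32_t v, uint32_t bits) {
  return static_cast<int32_t>(static_cast<uint32_t>(v) << (bits & 31u));
}
```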
template @@ -2751,15 +3376,23 @@ namespace detail { template HWY_API Vec128 Shr(hwy::UnsignedTag /*tag*/, Vec128 v, Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else return Vec128{vec_sr(v.raw, bits.raw)}; +#endif } template HWY_API Vec128 Shr(hwy::SignedTag /*tag*/, Vec128 v, Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else const DFromV di; const RebindToUnsigned du; return Vec128{vec_sra(v.raw, BitCast(du, bits).raw)}; +#endif } } // namespace detail @@ -2771,100 +3404,85 @@ HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { // ------------------------------ MulEven/Odd 64x64 (UpperHalf) -HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { #if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) - using VU64 = __vector unsigned long long; - const VU64 mul128_result = reinterpret_cast(vec_mule(a.raw, b.raw)); + using V64 = typename detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mule(a.raw, b.raw)); #if HWY_IS_LITTLE_ENDIAN - return Vec128{mul128_result}; + return Vec128{mul128_result}; #else // Need to swap the two halves of mul128_result on big-endian targets as // the upper 64 bits of the product are in lane 0 of mul128_result and // the lower 64 bits of the product are in lane 1 of mul128_result - return Vec128{vec_sld(mul128_result, mul128_result, 8)}; + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; #endif #else - alignas(16) uint64_t mul[2]; + alignas(16) T mul[2]; mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); - return Load(Full128(), mul); + return Load(Full128(), mul); #endif } -HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { #if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) - using VU64 = __vector unsigned long long; - const VU64 mul128_result = reinterpret_cast(vec_mulo(a.raw, b.raw)); + using V64 = typename detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mulo(a.raw, b.raw)); #if HWY_IS_LITTLE_ENDIAN - return Vec128{mul128_result}; + return Vec128{mul128_result}; #else // Need to swap the two halves of mul128_result on big-endian targets as // the upper 64 bits of the product are in lane 0 of mul128_result and // the lower 64 bits of the product are in lane 1 of mul128_result - return Vec128{vec_sld(mul128_result, mul128_result, 8)}; + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; #endif #else - alignas(16) uint64_t mul[2]; - const Full64 d2; + alignas(16) T mul[2]; + const Full64 d2; mul[0] = Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); - return Load(Full128(), mul); + return Load(Full128(), mul); #endif } +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + // ------------------------------ WidenMulPairwiseAdd -template >> -HWY_API VFromD WidenMulPairwiseAdd(D32 df32, V16 a, V16 b) { - const RebindToUnsigned du32; - // Lane order within sum0/1 is undefined, hence we can avoid the - // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip - // leads to the odd/even order that RearrangeToOddPlusEven prefers. 
- using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), - Mul(BitCast(df32, ao), BitCast(df32, bo))); +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); } // Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. template >> HWY_API VFromD WidenMulPairwiseAdd(D32 d32, V16 a, V16 b) { +#if HWY_S390X_HAVE_Z14 + (void)d32; + return MulEven(a, b) + MulOdd(a, b); +#else return VFromD{vec_msum(a.raw, b.raw, Zero(d32).raw)}; +#endif } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, - VFromD sum0, - VFromD& sum1) { - const RebindToUnsigned du32; - // Lane order within sum0/1 is undefined, hence we can avoid the - // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip - // leads to the odd/even order that RearrangeToOddPlusEven prefers. - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); -} - // Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 /* tag */, V16 a, V16 b, +HWY_API VFromD ReorderWidenMulAccumulate(D32 /*d32*/, V16 a, V16 b, VFromD sum0, VFromD& /*sum1*/) { +#if HWY_S390X_HAVE_Z14 + return MulEven(a, b) + MulOdd(a, b) + sum0; +#else return VFromD{vec_msum(a.raw, b.raw, sum0.raw)}; +#endif } // ------------------------------ RearrangeToOddPlusEven @@ -2885,7 +3503,27 @@ HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { return Add(sum0, sum1); } +// ------------------------------ SatWidenMulPairwiseAccumulate +#if !HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{vec_msums(a.raw, b.raw, sum.raw)}; +} + +#endif // !HWY_S390X_HAVE_Z14 + // ------------------------------ SumOfMulQuadAccumulate +#if !HWY_S390X_HAVE_Z14 + #ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #else @@ -2925,11 +3563,12 @@ HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, const auto result_sum_0 = SumOfMulQuadAccumulate(di32, BitCast(du8, a), b, sum); - const auto result_sum_1 = ShiftLeft<8>(detail::AltivecVsum4sbs( - di32, And(b, BroadcastSignBit(a)).raw, Zero(di32).raw)); + const auto result_sum_1 = ShiftLeft<8>(SumsOf4(And(b, BroadcastSignBit(a)))); return result_sum_0 - result_sum_1; } +#endif // !HWY_S390X_HAVE_Z14 + // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) @@ -3018,29 +3657,59 @@ HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { } template -HWY_API 
VFromD PromoteTo(D /* tag */, VFromD> v) { +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteTo(di64, v)); +#else // VSX + (void)df64; const __vector signed int raw_v = InterleaveLower(v, v).raw; #if HWY_IS_LITTLE_ENDIAN return VFromD{vec_doubleo(raw_v)}; #else return VFromD{vec_doublee(raw_v)}; #endif +#endif // HWY_S390X_HAVE_Z14 } template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteTo(du64, v)); +#else // VSX + (void)df64; const __vector unsigned int raw_v = InterleaveLower(v, v).raw; #if HWY_IS_LITTLE_ENDIAN return VFromD{vec_doubleo(raw_v)}; #else return VFromD{vec_doublee(raw_v)}; #endif +#endif // HWY_S390X_HAVE_Z14 +} + +#if !HWY_S390X_HAVE_Z14 +namespace detail { + +template +static HWY_INLINE V VsxF2INormalizeSrcVals(V v) { +#if !defined(HWY_DISABLE_PPC_VSX_QEMU_F2I_WORKAROUND) + // Workaround for QEMU 7/8 VSX float to int conversion bug + return IfThenElseZero(v == v, v); +#else + return v; +#endif } +} // namespace detail +#endif // !HWY_S390X_HAVE_Z14 + template HWY_API VFromD PromoteTo(D di64, VFromD> v) { -#if HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds) - const __vector float raw_v = InterleaveLower(v, v).raw; +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds)) + const __vector float raw_v = + detail::VsxF2INormalizeSrcVals(InterleaveLower(v, v)).raw; return VFromD{__builtin_vsx_xvcvspsxds(raw_v)}; #else const RebindToFloat df64; @@ -3050,8 +3719,10 @@ HWY_API VFromD PromoteTo(D di64, VFromD> v) { template HWY_API VFromD PromoteTo(D du64, VFromD> v) { -#if HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds) - const __vector float raw_v = InterleaveLower(v, v).raw; +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds)) + const __vector float raw_v = + detail::VsxF2INormalizeSrcVals(InterleaveLower(v, v)).raw; return VFromD{reinterpret_cast<__vector unsigned long long>( __builtin_vsx_xvcvspuxds(raw_v))}; #else @@ -3123,7 +3794,12 @@ HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { } template -HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteUpperTo(di64, v)); +#else // VSX + (void)df64; const __vector signed int raw_v = InterleaveUpper(Full128(), v, v).raw; #if HWY_IS_LITTLE_ENDIAN @@ -3131,10 +3807,16 @@ HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { #else return VFromD{vec_doublee(raw_v)}; #endif +#endif // HWY_S390X_HAVE_Z14 } template -HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteUpperTo(du64, v)); +#else // VSX + (void)df64; const __vector unsigned int raw_v = InterleaveUpper(Full128(), v, v).raw; #if HWY_IS_LITTLE_ENDIAN @@ -3142,12 +3824,16 @@ HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { #else return VFromD{vec_doublee(raw_v)}; #endif +#endif // HWY_S390X_HAVE_Z14 } template HWY_API VFromD PromoteUpperTo(D di64, Vec128 v) { -#if HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds) - const __vector float raw_v = InterleaveUpper(Full128(), v, v).raw; +#if !HWY_S390X_HAVE_Z14 && \ + 
(HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds)) + const __vector float raw_v = + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v)) + .raw; return VFromD{__builtin_vsx_xvcvspsxds(raw_v)}; #else const RebindToFloat df64; @@ -3157,8 +3843,11 @@ HWY_API VFromD PromoteUpperTo(D di64, Vec128 v) { template HWY_API VFromD PromoteUpperTo(D du64, Vec128 v) { -#if HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds) - const __vector float raw_v = InterleaveUpper(Full128(), v, v).raw; +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds)) + const __vector float raw_v = + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v)) + .raw; return VFromD{reinterpret_cast<__vector unsigned long long>( __builtin_vsx_xvcvspuxds(raw_v))}; #else @@ -3174,6 +3863,219 @@ HWY_API VFromD PromoteUpperTo(D d, V v) { return PromoteTo(d, UpperHalf(dh, v)); } +// ------------------------------ PromoteEvenTo/PromoteOddTo + +namespace detail { + +// Signed to Signed PromoteEvenTo/PromoteOddTo for PPC9/PPC10 +#if HWY_PPC_HAVE_9 && \ + (HWY_COMPILER_GCC_ACTUAL >= 1200 || HWY_COMPILER_CLANG >= 1200) + +#if HWY_IS_LITTLE_ENDIAN +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#endif // HWY_IS_LITTLE_ENDIAN + +#endif // HWY_PPC_HAVE_9 + +// I32/U32/F32->F64 PromoteEvenTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +template )> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteEvenTo(dw, v)); +} +#else // VSX +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds)) + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // __builtin_vsx_xvcvspsxds expects the source values to be in the odd lanes + // on little-endian PPC, and the vec_sld operation below will shift the even + // lanes of normalized_v into the odd lanes. 
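// Scalar sketch of the VsxF2INormalizeSrcVals workaround used above, assuming
// its only purpose is to keep the buggy QEMU float-to-int conversion from ever
// seeing NaN: NaN is the one value for which v == v is false, so those lanes
// are replaced by zero (illustrative helper, not a Highway API).
static inline float NormalizeForF2I(float v) {
  return (v == v) ? v : 0.0f;  // NaN -> 0; every other value passes through
}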
+ return VFromD{ + __builtin_vsx_xvcvspsxds(vec_sld(normalized_v.raw, normalized_v.raw, 4))}; +#else + // __builtin_vsx_xvcvspsxds expects the source values to be in the even lanes + // on big-endian PPC. + return VFromD{__builtin_vsx_xvcvspsxds(normalized_v.raw)}; +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds)) + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // __builtin_vsx_xvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC, and the vec_sld operation below will shift the even + // lanes of normalized_v into the odd lanes. + return VFromD{ + reinterpret_cast<__vector unsigned long long>(__builtin_vsx_xvcvspuxds( + vec_sld(normalized_v.raw, normalized_v.raw, 4)))}; +#else + // __builtin_vsx_xvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC. + return VFromD{reinterpret_cast<__vector unsigned long long>( + __builtin_vsx_xvcvspuxds(normalized_v.raw))}; +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// I32/U32/F32->F64 PromoteOddTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { + return PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), hwy::FloatTag(), + d_to, V{vec_sld(v.raw, v.raw, 4)}); +} +template )> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteOddTo(dw, v)); +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doubleo(v.raw)}; +} +#endif + +// F32->I64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds)) + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // __builtin_vsx_xvcvspsxds expects the source values to be in the odd lanes + // on little-endian PPC + return VFromD{__builtin_vsx_xvcvspsxds(normalized_v.raw)}; +#else + // __builtin_vsx_xvcvspsxds expects the source values to be in the even lanes + // on big-endian PPC, and the vec_sld operation below will shift the odd lanes + // of normalized_v into the even lanes. 
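// Scalar model of the lane selection behind PromoteEvenTo / PromoteOddTo in
// this file, shown for float -> double where no range clamping is involved
// (the helper below is illustrative, not part of Highway):
static inline void PromoteEvenOddToF64(const float in[4], double even_out[2],
                                       double odd_out[2]) {
  even_out[0] = in[0];  // PromoteEvenTo reads lanes 0 and 2
  even_out[1] = in[2];
  odd_out[0] = in[1];   // PromoteOddTo reads lanes 1 and 3
  odd_out[1] = in[3];
}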
+ return VFromD{ + __builtin_vsx_xvcvspsxds(vec_sld(normalized_v.raw, normalized_v.raw, 4))}; +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds)) + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // __builtin_vsx_xvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC + return VFromD{reinterpret_cast<__vector unsigned long long>( + __builtin_vsx_xvcvspuxds(normalized_v.raw))}; +#else + // __builtin_vsx_xvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC, and the vec_sld operation below will shift the odd lanes + // of normalized_v into the even lanes. + return VFromD{ + reinterpret_cast<__vector unsigned long long>(__builtin_vsx_xvcvspuxds( + vec_sld(normalized_v.raw, normalized_v.raw, 4)))}; +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +} // namespace detail + // ------------------------------ Demotions (full -> part w/ narrow lanes) template DemoteTo(D df16, VFromD> v) { #endif // HWY_PPC_HAVE_9 -template -HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = ShiftRight<16>(BitCast(du32, v)); - return BitCast(dbf16, TruncateTo(du16, bits_in_32)); +#if HWY_PPC_HAVE_9 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +// On big-endian PPC9, VsxXscvdphp converts vf64[0] to a F16, returned as an U64 +// vector with the resulting F16 bits in the lower 16 bits of U64 lane 0 + +// On little-endian PPC9, VsxXscvdphp converts vf64[1] to a F16, returned as +// an U64 vector with the resulting F16 bits in the lower 16 bits of U64 lane 1 +static HWY_INLINE Vec128 VsxXscvdphp(Vec128 vf64) { + // Inline assembly is needed for the PPC9 xscvdphp instruction as there is + // currently no intrinsic available for the PPC9 xscvdphp instruction + __vector unsigned long long raw_result; + __asm__("xscvdphp %x0, %x1" : "=wa"(raw_result) : "wa"(vf64.raw)); + return Vec128{raw_result}; } -template >> -HWY_API VFromD ReorderDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; - const Repartition du32; +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du64; + + const Full128 df64_full; #if HWY_IS_LITTLE_ENDIAN - const auto a_in_odd = a; - const auto b_in_even = ShiftRight<16>(BitCast(du32, b)); + const auto bits16_as_u64 = + UpperHalf(du64, detail::VsxXscvdphp(Combine(df64_full, v, v))); #else - const auto a_in_odd = ShiftRight<16>(BitCast(du32, a)); - const auto b_in_even = b; + const auto bits16_as_u64 = + LowerHalf(du64, detail::VsxXscvdphp(ResizeBitCast(df64_full, v))); #endif - return BitCast(dbf16, - OddEven(BitCast(du16, a_in_odd), BitCast(du16, b_in_even))); + + return BitCast(df16, TruncateTo(du16, bits16_as_u64)); } +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + 
const Rebind du64; + const Rebind df64; + +#if HWY_IS_LITTLE_ENDIAN + const auto bits64_as_u64_0 = detail::VsxXscvdphp(InterleaveLower(df64, v, v)); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(v); + const auto bits64_as_u64 = + InterleaveUpper(du64, bits64_as_u64_0, bits64_as_u64_1); +#else + const auto bits64_as_u64_0 = detail::VsxXscvdphp(v); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(InterleaveUpper(df64, v, v)); + const auto bits64_as_u64 = + InterleaveLower(du64, bits64_as_u64_0, bits64_as_u64_1); +#endif + + return BitCast(df16, TruncateTo(du16, bits64_as_u64)); +} + +#elif HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +template +static HWY_INLINE VFromD DemoteToF32WithRoundToOdd( + DF32 df32, VFromD> v) { + const Twice dt_f32; + + __vector float raw_f32_in_even; + __asm__("vledb %0,%1,0,3" : "=v"(raw_f32_in_even) : "v"(v.raw)); + + const VFromD f32_in_even{raw_f32_in_even}; + return LowerHalf(df32, ConcatEven(dt_f32, f32_in_even, f32_in_even)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} + +#endif // HWY_PPC_HAVE_9 + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { + +// VsxXvcvspbf16 converts a F32 vector to a BF16 vector, bitcasted to an U32 +// vector with the resulting BF16 bits in the lower 16 bits of each U32 lane +template +static HWY_INLINE VFromD> VsxXvcvspbf16( + D dbf16, VFromD> v) { + const Rebind du32; + const Repartition du32_as_du8; + + using VU32 = __vector unsigned int; + + // Even though the __builtin_vsx_xvcvspbf16 builtin performs a F32 to BF16 + // conversion, the __builtin_vsx_xvcvspbf16 intrinsic expects a + // __vector unsigned char argument (at least as of GCC 13 and Clang 17) + return VFromD>{reinterpret_cast( + __builtin_vsx_xvcvspbf16(BitCast(du32_as_du8, v).raw))}; +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast(dbf16, TruncateTo(du16, detail::VsxXvcvspbf16(dbf16, v))); +} + +#endif // HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + // Specializations for partial vectors because vec_packs sets lanes above 2*N. 
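// Scalar sketch of the f32 -> bf16 lane narrowing that VsxXvcvspbf16 feeds
// into TruncateTo above: bfloat16 reuses the upper 16 bits of the binary32
// encoding. This truncating form mirrors the removed generic fallback; the
// PPC10 instruction itself may round rather than truncate.
#include <cstdint>
#include <cstring>

static inline uint16_t F32ToBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // type-pun via memcpy, no UB
  return static_cast<uint16_t>(bits >> 16);
}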
template ReorderDemote2To(DN /*dn*/, V a, V b) { return VFromD{vec_packs(a.raw, b.raw)}; } +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V) * 2)> +HWY_API VFromD ReorderDemote2To(D dbf16, V a, V b) { + const RebindToUnsigned du16; + const Half dh_bf16; + return BitCast(dbf16, + OrderedTruncate2To(du16, detail::VsxXvcvspbf16(dh_bf16, a), + detail::VsxXvcvspbf16(dh_bf16, b))); +} +#endif + template ), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), @@ -3376,15 +4399,13 @@ HWY_API VFromD OrderedDemote2To(D d, V a, V b) { return ReorderDemote2To(d, a, b); } -template >> -HWY_API VFromD OrderedDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; -#if HWY_IS_LITTLE_ENDIAN - return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); -#else - return BitCast(dbf16, ConcatEven(du16, BitCast(du16, b), BitCast(du16, a))); -#endif +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); } +#endif template HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { @@ -3393,90 +4414,164 @@ HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { template HWY_API Vec64 DemoteTo(D d, Vec128 v) { -#if HWY_IS_LITTLE_ENDIAN +#if HWY_S390X_HAVE_Z14 || HWY_IS_LITTLE_ENDIAN const Vec128 f64_to_f32{vec_floate(v.raw)}; #else const Vec128 f64_to_f32{vec_floato(v.raw)}; #endif +#if HWY_S390X_HAVE_Z14 + const Twice dt; + return LowerHalf(d, ConcatEven(dt, f64_to_f32, f64_to_f32)); +#else const RebindToUnsigned du; const Rebind du64; return Vec64{ BitCast(d, TruncateTo(du, BitCast(du64, f64_to_f32))).raw}; +#endif } template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { - return Vec32{vec_signede(v.raw)}; +HWY_API Vec32 DemoteTo(D di32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + return Vec32{vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif } template -HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { +HWY_API Vec64 DemoteTo(D di32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + #if HWY_IS_LITTLE_ENDIAN - const Vec128 f64_to_i32{vec_signede(v.raw)}; + const Vec128 f64_to_i32{ + vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; #else - const Vec128 f64_to_i32{vec_signedo(v.raw)}; + const Vec128 f64_to_i32{ + vec_signedo(detail::VsxF2INormalizeSrcVals(v).raw)}; #endif const Rebind di64; const Vec128 vi64 = BitCast(di64, f64_to_i32); return Vec64{vec_pack(vi64.raw, vi64.raw)}; +#endif } template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { - return Vec32{vec_unsignede(v.raw)}; +HWY_API Vec32 DemoteTo(D du32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; + return Vec32{vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif } template -HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { +HWY_API Vec64 DemoteTo(D du32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; #if HWY_IS_LITTLE_ENDIAN - const Vec128 f64_to_u32{vec_unsignede(v.raw)}; + const Vec128 f64_to_u32{ + vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; #else - const Vec128 f64_to_u32{vec_unsignedo(v.raw)}; + const Vec128 f64_to_u32{ + 
vec_unsignedo(detail::VsxF2INormalizeSrcVals(v).raw)}; #endif const Rebind du64; const Vec128 vu64 = BitCast(du64, f64_to_u32); return Vec64{vec_pack(vu64.raw, vu64.raw)}; +#endif +} + +#if HWY_S390X_HAVE_Z14 +namespace detail { + +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd I64->F64 conversion on Z14 + __asm__("vcdgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; +} + +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd U64->F64 conversion on Z14 + __asm__("vcdlgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; } +} // namespace detail +#endif // HWY_S390X_HAVE_Z14 + template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; return Vec32{vec_floate(v.raw)}; +#endif } template -HWY_API Vec64 DemoteTo(D d, Vec128 v) { +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX #if HWY_IS_LITTLE_ENDIAN const Vec128 i64_to_f32{vec_floate(v.raw)}; #else const Vec128 i64_to_f32{vec_floato(v.raw)}; #endif - const RebindToUnsigned du; - const Rebind du64; + const RebindToUnsigned du32; + const Rebind du64; return Vec64{ - BitCast(d, TruncateTo(du, BitCast(du64, i64_to_f32))).raw}; + BitCast(df32, TruncateTo(du32, BitCast(du64, i64_to_f32))).raw}; +#endif } template -HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; return Vec32{vec_floate(v.raw)}; +#endif } template -HWY_API Vec64 DemoteTo(D d, Vec128 v) { +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX #if HWY_IS_LITTLE_ENDIAN const Vec128 u64_to_f32{vec_floate(v.raw)}; #else const Vec128 u64_to_f32{vec_floato(v.raw)}; #endif - const RebindToUnsigned du; - const Rebind du64; + const RebindToUnsigned du; + const Rebind du64; return Vec64{ - BitCast(d, TruncateTo(du, BitCast(du64, u64_to_f32))).raw}; + BitCast(df32, TruncateTo(du, BitCast(du64, u64_to_f32))).raw}; +#endif } // For already range-limited input [0, 255]. @@ -3491,17 +4586,39 @@ HWY_API Vec128 U8FromU32(Vec128 v) { // Note: altivec.h vec_ct* currently contain C casts which triggers // -Wdeprecate-lax-vec-conv-all warnings, so disable them. 
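// Why the round-to-odd intermediates above exist: narrowing int64_t -> double
// -> float with round-to-nearest at both steps can double-round. A standalone
// illustration; the chosen constant sits just above a float rounding tie, and
// the expected values assume the usual round-to-nearest conversion behavior.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t x = (int64_t{1} << 53) + (int64_t{1} << 29) + 1;
  const float direct = static_cast<float>(x);                       // one rounding step
  const float via_f64 = static_cast<float>(static_cast<double>(x));  // two steps
  // On typical targets: direct == 2^53 + 2^30, via_f64 == 2^53.
  std::printf("direct=%.1f via_f64=%.1f\n", direct, via_f64);
  return 0;
}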
-template +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D df32, + Vec128().MaxLanes()> v) { + const Rebind df64; + return DemoteTo(df32, PromoteTo(df64, v)); +} +template +HWY_API VFromD ConvertTo(D df32, Vec128 v) { + const RepartitionToWide df64; + + const VFromD vf32_lo{vec_floate(PromoteLowerTo(df64, v).raw)}; + const VFromD vf32_hi{vec_floate(PromoteUpperTo(df64, v).raw)}; + return ConcatEven(df32, vf32_hi, vf32_lo); +} +#else // Z15 or PPC +template HWY_API VFromD ConvertTo(D /* tag */, Vec128().MaxLanes()> v) { HWY_DIAGNOSTICS(push) #if HWY_COMPILER_CLANG HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") #endif +#if HWY_S390X_HAVE_Z15 + return VFromD{vec_float(v.raw)}; +#else return VFromD{vec_ctf(v.raw, 0)}; +#endif HWY_DIAGNOSTICS(pop) } +#endif // HWY_TARGET == HWY_Z14 template @@ -3511,38 +4628,195 @@ HWY_API VFromD ConvertTo(D /* tag */, } // Truncates (rounds toward zero). -template +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const Rebind di64; + return DemoteTo(di32, PromoteTo(di64, v)); +} +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const RepartitionToWide di64; + return OrderedDemote2To(di32, PromoteLowerTo(di64, v), + PromoteUpperTo(di64, v)); +} +#else // Z15 or PPC +template HWY_API VFromD ConvertTo(D /* tag */, - Vec128().MaxLanes()> v) { + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int32_t kMinI32 = LimitsMin(); + constexpr int32_t kMaxI32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= -2147483648.0f) + ? ((v.raw[0] < 2147483648.0f) ? static_cast(v.raw[0]) + : kMaxI32) + : ((v.raw[0] < 0) ? kMinI32 : 0), + (v.raw[1] >= -2147483648.0f) + ? ((v.raw[1] < 2147483648.0f) ? static_cast(v.raw[1]) + : kMaxI32) + : ((v.raw[1] < 0) ? kMinI32 : 0), + (v.raw[2] >= -2147483648.0f) + ? ((v.raw[2] < 2147483648.0f) ? static_cast(v.raw[2]) + : kMaxI32) + : ((v.raw[2] < 0) ? kMinI32 : 0), + (v.raw[3] >= -2147483648.0f) + ? ((v.raw[3] < 2147483648.0f) ? static_cast(v.raw[3]) + : kMaxI32) + : ((v.raw[3] < 0) ? kMinI32 : 0)); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an int32_t + __vector signed int raw_result; + __asm__("vcfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else HWY_DIAGNOSTICS(push) #if HWY_COMPILER_CLANG HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") #endif return VFromD{vec_cts(v.raw, 0)}; HWY_DIAGNOSTICS(pop) +#endif // HWY_S390X_HAVE_Z15 } +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 -template +template HWY_API VFromD ConvertTo(D /* tag */, - Vec128().MaxLanes()> v) { + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int64_t kMinI64 = LimitsMin(); + constexpr int64_t kMaxI64 = LimitsMax(); + return Dup128VecFromValues(D(), + (v.raw[0] >= -9223372036854775808.0) + ? ((v.raw[0] < 9223372036854775808.0) + ? static_cast(v.raw[0]) + : kMaxI64) + : ((v.raw[0] < 0) ? kMinI64 : 0LL), + (v.raw[1] >= -9223372036854775808.0) + ? ((v.raw[1] < 9223372036854775808.0) + ? static_cast(v.raw[1]) + : kMaxI64) + : ((v.raw[1] < 0) ? 
kMinI64 : 0LL)); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an int64_t + __vector signed long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vcgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else + __asm__("xvcvdpsxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; +} + +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const Rebind du64; + return DemoteTo(du32, PromoteTo(du64, v)); +} +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const RepartitionToWide du64; + return OrderedDemote2To(du32, PromoteLowerTo(du64, v), + PromoteUpperTo(du64, v)); +} +#else // Z15 or VSX +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint32_t kMaxU32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0f) + ? ((v.raw[0] < 4294967296.0f) ? static_cast(v.raw[0]) + : kMaxU32) + : 0, + (v.raw[1] >= 0.0f) + ? ((v.raw[1] < 4294967296.0f) ? static_cast(v.raw[1]) + : kMaxU32) + : 0, + (v.raw[2] >= 0.0f) + ? ((v.raw[2] < 4294967296.0f) ? static_cast(v.raw[2]) + : kMaxU32) + : 0, + (v.raw[3] >= 0.0f) + ? ((v.raw[3] < 4294967296.0f) ? static_cast(v.raw[3]) + : kMaxU32) + : 0); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an uint32_t + __vector unsigned int raw_result; + __asm__("vclfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else // VSX HWY_DIAGNOSTICS(push) #if HWY_COMPILER_CLANG HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") #endif - return VFromD{vec_ctu(ZeroIfNegative(v).raw, 0)}; + VFromD result{vec_ctu(v.raw, 0)}; HWY_DIAGNOSTICS(pop) + return result; +#endif // HWY_S390X_HAVE_Z15 } +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 -template -HWY_API Vec128 NearestInt(Vec128 v) { +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { HWY_DIAGNOSTICS(push) #if HWY_COMPILER_CLANG HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") #endif - return Vec128{vec_cts(vec_round(v.raw), 0)}; - HWY_DIAGNOSTICS(pop) + +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint64_t kMaxU64 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0) ? ((v.raw[0] < 18446744073709551616.0) + ? static_cast(v.raw[0]) + : kMaxU64) + : 0, + (v.raw[1] >= 0.0) ? ((v.raw[1] < 18446744073709551616.0) + ? 
static_cast(v.raw[1]) + : kMaxU64) + : 0); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an uint64_t + __vector unsigned long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vclgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else // VSX + __asm__("xvcvdpuxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; } // ------------------------------ Floating-point rounding (ConvertTo) @@ -3555,7 +4829,24 @@ HWY_API Vec128 Round(Vec128 v) { template HWY_API Vec128 Round(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Vec128{vec_round(v.raw)}; +#else return Vec128{vec_rint(v.raw)}; +#endif +} + +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Round(v)); +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, Round(v)); } // Toward zero, aka truncate @@ -3613,7 +4904,7 @@ HWY_API Mask128 IsFinite(Vec128 v) { // ================================================== CRYPTO -#if !defined(HWY_DISABLE_PPC8_CRYPTO) +#if !HWY_S390X_HAVE_Z14 && !defined(HWY_DISABLE_PPC8_CRYPTO) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES @@ -3918,6 +5209,15 @@ struct CompressIsPartition { enum { value = (sizeof(T) != 1) }; }; +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits128(d, mask_bits); +} + // ------------------------------ StoreMaskBits namespace detail { @@ -3930,37 +5230,45 @@ HWY_INLINE uint64_t ExtractSignBits(Vec128 sign_bits, // clang POWER8 and 9 targets appear to differ in their return type of // vec_vbpermq: unsigned or signed, so cast to avoid a warning. 
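// Scalar model of the per-lane MSB gather that vec_vbpermq / vec_bperm_u128 /
// vec_extractm implement for BitsFromMask: each mask byte is either 0x00 or
// 0xFF, so collecting bit 7 of every byte yields the mask bits. The bit
// ordering here is illustrative; the shuffles above select the
// target-specific order.
#include <cstdint>

static inline uint64_t BitsFromMaskBytes(const uint8_t mask_bytes[16]) {
  uint64_t bits = 0;
  for (int i = 0; i < 16; ++i) {
    bits |= static_cast<uint64_t>(mask_bytes[i] >> 7) << i;  // MSB of byte i
  }
  return bits;
}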
using VU64 = detail::Raw128::type; +#if HWY_S390X_HAVE_Z14 + const Vec128 extracted{ + reinterpret_cast(vec_bperm_u128(sign_bits.raw, bit_shuffle))}; +#else const Vec128 extracted{ reinterpret_cast(vec_vbpermq(sign_bits.raw, bit_shuffle))}; +#endif return extracted.raw[HWY_IS_LITTLE_ENDIAN]; } -#endif // !HWY_PPC_HAVE_10 +#endif // !HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, Mask128 mask) { const DFromM d; const Repartition du8; const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + #if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN return static_cast(vec_extractm(sign_bits.raw)); -#else +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 const __vector unsigned char kBitShuffle = {120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0}; return ExtractSignBits(sign_bits, kBitShuffle); -#endif // HWY_PPC_HAVE_10 +#endif // HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, Mask128 mask) { const DFromM d; + const RebindToUnsigned du; + const Repartition du8; const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); #if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN - const RebindToUnsigned du; return static_cast(vec_extractm(BitCast(du, sign_bits).raw)); -#else +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; #if HWY_IS_LITTLE_ENDIAN const __vector unsigned char kBitShuffle = { 112, 96, 80, 64, 48, 32, 16, 0, 128, 128, 128, 128, 128, 128, 128, 128}; @@ -3975,12 +5283,15 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, Mask128 mask) { template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, Mask128 mask) { const DFromM d; + const RebindToUnsigned du; + const Repartition du8; const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + #if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN - const RebindToUnsigned du; return static_cast(vec_extractm(BitCast(du, sign_bits).raw)); -#else +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; #if HWY_IS_LITTLE_ENDIAN const __vector unsigned char kBitShuffle = {96, 64, 32, 0, 128, 128, 128, 128, 128, 128, 128, 128, @@ -3997,12 +5308,15 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, Mask128 mask) { template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, Mask128 mask) { const DFromM d; + const RebindToUnsigned du; + const Repartition du8; const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + #if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN - const RebindToUnsigned du; return static_cast(vec_extractm(BitCast(du, sign_bits).raw)); -#else +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; #if HWY_IS_LITTLE_ENDIAN const __vector unsigned char kBitShuffle = {64, 0, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -4076,31 +5390,32 @@ HWY_API size_t StoreMaskBits(D /*d*/, MFromD mask, uint8_t* bits) { template HWY_API bool AllFalse(D d, MFromD mask) { const RebindToUnsigned du; - return static_cast(vec_all_eq(RebindMask(du, mask).raw, Zero(du).raw)); + return static_cast( + vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, Zero(du).raw)); } template HWY_API bool AllTrue(D d, MFromD mask) { const RebindToUnsigned du; using TU = TFromD; - return static_cast( - vec_all_eq(RebindMask(du, mask).raw, Set(du, hwy::LimitsMax()).raw)); + return static_cast(vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, + Set(du, hwy::LimitsMax()).raw)); } template HWY_API bool AllFalse(D d, MFromD mask) { const Full128> d_full; constexpr size_t kN = MaxLanes(d); - 
return AllFalse(d_full, MFromD{ - vec_and(mask.raw, FirstN(d_full, kN).raw)}); + return AllFalse(d_full, + And(MFromD{mask.raw}, FirstN(d_full, kN))); } template HWY_API bool AllTrue(D d, MFromD mask) { const Full128> d_full; constexpr size_t kN = MaxLanes(d); - return AllTrue(d_full, MFromD{ - vec_or(mask.raw, Not(FirstN(d_full, kN)).raw)}); + return AllTrue( + d_full, Or(MFromD{mask.raw}, Not(FirstN(d_full, kN)))); } template @@ -4222,7 +5537,7 @@ HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { __asm__("xxgenpcvbm %x0, %1, %2" : "=wa"(idx) : "v"(mask.raw), "i"(kGenPcvmMode)); - return VFromD{idx}; + return VFromD{idx}; } template HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { @@ -4235,7 +5550,7 @@ HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { __asm__("xxgenpcvhm %x0, %1, %2" : "=wa"(idx) : "v"(mask.raw), "i"(kGenPcvmMode)); - return VFromD{idx}; + return VFromD{idx}; } template HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { @@ -4248,7 +5563,7 @@ HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { __asm__("xxgenpcvwm %x0, %1, %2" : "=wa"(idx) : "v"(mask.raw), "i"(kGenPcvmMode)); - return VFromD{idx}; + return VFromD{idx}; } #endif @@ -4821,7 +6136,7 @@ HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); -#if HWY_PPC_HAVE_9 +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 StoreN(compressed, d, unaligned, count); #else BlendedStore(compressed, FirstN(d, count), d, unaligned); @@ -4939,7 +6254,18 @@ HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) { template HWY_INLINE V I128Subtract(V a, V b) { -#if defined(__SIZEOF_INT128__) +#if HWY_S390X_HAVE_Z14 +#if HWY_COMPILER_CLANG + // Workaround for bug in vec_sub_u128 in Clang vecintrin.h + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const V diff_i128{reinterpret_cast>::type>( + reinterpret_cast(a.raw) - reinterpret_cast(b.raw))}; +#else // !HWY_COMPILER_CLANG + const V diff_i128{reinterpret_cast>::type>( + vec_sub_u128(reinterpret_cast<__vector unsigned char>(a.raw), + reinterpret_cast<__vector unsigned char>(b.raw)))}; +#endif // HWY_COMPILER_CLANG +#elif defined(__SIZEOF_INT128__) using VU128 = __vector unsigned __int128; const V diff_i128{reinterpret_cast>::type>( vec_sub(reinterpret_cast(a.raw), reinterpret_cast(b.raw)))}; @@ -5067,84 +6393,133 @@ HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); } -// ------------------------------ Reductions - +// ------------------------------ SumsOf2 and SumsOf4 namespace detail { -// N=1 for any T: no-op -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - return v; +#if !HWY_S390X_HAVE_Z14 +// Casts nominally int32_t result to D. 
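// The constant-folded AltivecVsum* emulations below clamp a 64-bit sum into
// int32_t range. A scalar form of that check, assuming two's complement: the
// sum fits exactly when its sign word agrees with bits 31 and up, otherwise
// it saturates to INT32_MIN or INT32_MAX (illustrative helper only).
#include <cstdint>

static inline int32_t SaturateToI32(int64_t sum) {
  const int32_t sign = static_cast<int32_t>(sum >> 63);  // 0 or -1
  return (static_cast<int64_t>(sign) == (sum >> 31))
             ? static_cast<int32_t>(sum)
             : (sign ^ 0x7FFFFFFF);  // 0 -> INT32_MAX, -1 -> INT32_MIN
}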
+template +HWY_INLINE VFromD AltivecVsum4sbs(D d, __vector signed char a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const int64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const int64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const int64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const int64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + const int32_t sign2 = static_cast(sum2 >> 63); + const int32_t sign3 = static_cast(sum3 >> 63); + using Raw = typename detail::Raw128::type; + return BitCast( + d, + VFromD{Raw{ + (sign0 == (sum0 >> 31)) ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF), + (sign2 == (sum2 >> 31)) ? static_cast(sum2) + : static_cast(sign2 ^ 0x7FFFFFFF), + (sign3 == (sum3 >> 31)) + ? static_cast(sum3) + : static_cast(sign3 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4sbs(a, b)}); + } } -// u32/i32/f32: - -// N=2 -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v10) { - // NOTE: AltivecVsum2sws cannot be used here as AltivecVsum2sws - // computes the signed saturated sum of the lanes. - return v10 + Shuffle2301(v10); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v10) { - return Min(v10, Shuffle2301(v10)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v10) { - return Max(v10, Shuffle2301(v10)); +// Casts nominally uint32_t result to D. +template +HWY_INLINE VFromD AltivecVsum4ubs(D d, __vector unsigned char a, + __vector unsigned int b) { + const Repartition du32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const uint64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const uint64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const uint64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const uint64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + return BitCast( + d, + VFromD{(__vector unsigned int){ + static_cast(sum0 <= 0xFFFFFFFFu ? sum0 : 0xFFFFFFFFu), + static_cast(sum1 <= 0xFFFFFFFFu ? sum1 : 0xFFFFFFFFu), + static_cast(sum2 <= 0xFFFFFFFFu ? sum2 : 0xFFFFFFFFu), + static_cast(sum3 <= 0xFFFFFFFFu ? sum3 + : 0xFFFFFFFFu)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4ubs(a, b)}); + } } -// N=4 (full) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v3210) { - // NOTE: AltivecVsumsws cannot be used here as AltivecVsumsws - // computes the signed saturated sum of the lanes. 
- const Vec128 v1032 = Shuffle1032(v3210); - const Vec128 v31_20_31_20 = v3210 + v1032; - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); - return v20_31_20_31 + v31_20_31_20; -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v3210) { - const Vec128 v1032 = Shuffle1032(v3210); - const Vec128 v31_20_31_20 = Min(v3210, v1032); - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); - return Min(v20_31_20_31, v31_20_31_20); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v3210) { - const Vec128 v1032 = Shuffle1032(v3210); - const Vec128 v31_20_31_20 = Max(v3210, v1032); - const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); - return Max(v20_31_20_31, v31_20_31_20); -} +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum2sws(D d, __vector signed int a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + const Repartition du64; + constexpr int kDestLaneOffset = HWY_IS_BIG_ENDIAN; + if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset]) && + __builtin_constant_p(b[kDestLaneOffset + 2])) { + const int64_t sum0 = static_cast(a[0]) + + static_cast(a[1]) + + static_cast(b[kDestLaneOffset]); + const int64_t sum1 = static_cast(a[2]) + + static_cast(a[3]) + + static_cast(b[kDestLaneOffset + 2]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + return BitCast(d, VFromD{(__vector unsigned long long){ + (sign0 == (sum0 >> 31)) + ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) + ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + __vector signed int sum; -// u64/i64/f64: + // Inline assembly is used for vsum2sws to avoid unnecessary shuffling + // on little-endian PowerPC targets as the result of the vsum2sws + // instruction will already be in the correct lanes on little-endian + // PowerPC targets. + __asm__("vsum2sws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); -// N=2 (full) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); - return v10 + v01; -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); - return Min(v10, v01); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v10) { - const Vec128 v01 = Shuffle01(v10); - return Max(v10, v01); + return BitCast(d, VFromD{sum}); + } } // Casts nominally int32_t result to D. 
@@ -5238,275 +6613,440 @@ HWY_INLINE Vec128 AltivecU16SumsOf2(Vec128 v) { return AltivecVsum4shs(di32, Xor(BitCast(di16, v), Set(di16, -32768)).raw, Set(di32, 65536).raw); } +#endif // !HWY_S390X_HAVE_Z14 + +// U16->U32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, Zero(d).raw)}; +#else + return BitCast(dw, AltivecU16SumsOf2(v)); +#endif +} + +// I16->I32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int32_t{-65536}); +#else + return AltivecVsum4shs(dw, v.raw, Zero(dw).raw); +#endif +} + +#if HWY_S390X_HAVE_Z14 +// U32->U64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +} + +// I32->I64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned du; + + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int64_t{-4294967296LL}); +} +#endif + +// U8->U32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, Zero(d).raw)}; +#else + return AltivecVsum4ubs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// I8->I32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int32_t{-512}); +#else + return AltivecVsum4sbs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// U16->U64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +#else + const RebindToSigned dw_i; + return AltivecVsum2sws(dw2, BitCast(dw_i, SumsOf2(v)).raw, Zero(dw_i).raw); +#endif +} + +// I16->I64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int64_t{-131072}); +#else // VSX + const auto sums_of_4_in_lo32 = + AltivecVsum2sws(dw, SumsOf2(v).raw, Zero(dw).raw); + +#if HWY_IS_LITTLE_ENDIAN + return PromoteEvenTo(dw2, sums_of_4_in_lo32); +#else + return PromoteOddTo(dw2, sums_of_4_in_lo32); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +} // namespace detail + +// ------------------------------ 
SumOfLanes + +// We define SumOfLanes for 8/16-bit types (and I32/U32/I64/U64 on Z14/Z15/Z16); +// enable generic for the rest. +#undef HWY_IF_SUM_OF_LANES_D +#if HWY_S390X_HAVE_Z14 +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1), HWY_IF_FLOAT3264_D(D) +#else +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8)) +#endif + +#if HWY_S390X_HAVE_Z14 +namespace detail { -HWY_API Vec32 SumOfLanes(Vec32 v) { +#if HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_s390_vsumqf) && \ + HWY_HAS_BUILTIN(__builtin_s390_vsumqg) +// Workaround for bug in vec_sum_u128 in Clang vecintrin.h +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqf(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqg(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +#else +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, Vec128{vec_sum_u128(BitCast(du, v).raw, Zero(du).raw)}); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D /*d64*/, VFromD v) { + return Broadcast<1>(detail::SumOfU32OrU64LanesAsU128(v)); +} +#endif + +template +HWY_API Vec32 SumOfLanes(D du16, Vec32 v) { constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; - DFromV du16; - return Broadcast(BitCast(du16, AltivecU16SumsOf2(v))); + return Broadcast( + BitCast(du16, detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); } -HWY_API Vec64 SumOfLanes(Vec64 v) { +template +HWY_API Vec64 SumOfLanes(D du16, Vec64 v) { constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - const Full64 du16; - const auto zero = Zero(Full128()); return Broadcast( - AltivecVsum2sws(du16, AltivecU16SumsOf2(v).raw, zero.raw)); + BitCast(du16, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); } -HWY_API Vec128 SumOfLanes(Vec128 v) { +template +HWY_API Vec128 SumOfLanes(D du16, Vec128 v) { constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 7; - const Full128 du16; +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du16, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<2>(), v)))); +#else // VSX const auto zero = Zero(Full128()); return Broadcast( - AltivecVsumsws(du16, AltivecU16SumsOf2(v).raw, zero.raw)); + detail::AltivecVsumsws(du16, detail::AltivecU16SumsOf2(v).raw, zero.raw)); +#endif } -HWY_API Vec32 SumOfLanes(Vec32 v) { +template +HWY_API Vec32 SumOfLanes(D di16, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; - const Full32 di16; - const auto zero = Zero(Full128()); - return Broadcast(AltivecVsum4shs(di16, v.raw, zero.raw)); + return Broadcast( + BitCast(di16, detail::SumsOf2(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif } -HWY_API Vec64 SumOfLanes(Vec64 v) { +template +HWY_API Vec64 SumOfLanes(D di16, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - const Full128 di32; - const Full64 di16; - const auto zero = Zero(di32); - return Broadcast(AltivecVsum2sws( - di16, AltivecVsum4shs(di32, v.raw, zero.raw).raw, zero.raw)); + return Broadcast( + BitCast(di16, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif } -HWY_API Vec128 SumOfLanes(Vec128 v) { +template +HWY_API Vec128 SumOfLanes(D di16, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; - const Full128 di16; const Full128 di32; const auto zero = Zero(di32); - return Broadcast(AltivecVsumsws( - di16, AltivecVsum4shs(di32, v.raw, zero.raw).raw, zero.raw)); + return Broadcast(detail::AltivecVsumsws( + di16, detail::AltivecVsum4shs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif } -// u8, N=2, N=4, N=8, N=16: -HWY_API Vec16 SumOfLanes(Vec16 v) { +template +HWY_API Vec32 SumOfLanes(D du8, Vec32 v) { constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - const Full16 du8; - const Full16 du16; - const Twice dt_u8; - const Twice dt_u16; - const Full128 du32; - return LowerHalf(Broadcast(AltivecVsum4ubs( - dt_u8, BitCast(dt_u8, Combine(dt_u16, Zero(du16), BitCast(du16, v))).raw, - Zero(du32).raw))); + return Broadcast( + BitCast(du8, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v))); } -HWY_API Vec32 SumOfLanes(Vec32 v) { - constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - const Full128 du32; - const Full32 du8; - return Broadcast(AltivecVsum4ubs(du8, v.raw, Zero(du32).raw)); +template +HWY_API Vec16 SumOfLanes(D du8, Vec16 v) { + const Twice dt_u8; + return LowerHalf(du8, SumOfLanes(dt_u8, Combine(dt_u8, Zero(du8), v))); } -HWY_API Vec64 SumOfLanes(Vec64 v) { +template +HWY_API Vec64 SumOfLanes(D du8, Vec64 v) { constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; - const Full64 du8; return Broadcast(BitCast(du8, SumsOf8(v))); } -HWY_API Vec128 SumOfLanes(Vec128 v) { +template +HWY_API Vec128 SumOfLanes(D du8, Vec128 v) { constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 15; +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du8, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<1>(), v)))); +#else const Full128 du32; const RebindToSigned di32; - const Full128 du8; const Vec128 zero = Zero(du32); - return Broadcast( - AltivecVsumsws(du8, AltivecVsum4ubs(di32, v.raw, zero.raw).raw, - BitCast(di32, zero).raw)); + return Broadcast(detail::AltivecVsumsws( + du8, detail::AltivecVsum4ubs(di32, v.raw, zero.raw).raw, + BitCast(di32, zero).raw)); +#endif } -HWY_API Vec16 SumOfLanes(Vec16 v) { +template +HWY_API Vec32 SumOfLanes(D di8, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - - const Full128 du16; - const Repartition di32; - const Repartition di8; - const Vec128 zzvv = BitCast( - di8, InterleaveLower(BitCast(du16, Vec128{v.raw}), Zero(du16))); - return Vec16{ - Broadcast(AltivecVsum4sbs(di8, zzvv.raw, Zero(di32).raw)) - .raw}; + return Broadcast( + BitCast(di8, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<1>(), v))); +#endif } -HWY_API Vec32 SumOfLanes(Vec32 v) { - constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; - const Full32 di8; - const Vec128 zero = Zero(Full128()); - return Broadcast(AltivecVsum4sbs(di8, v.raw, zero.raw)); +template +HWY_API Vec16 SumOfLanes(D di8, Vec16 v) { + const Twice dt_i8; + return LowerHalf(di8, SumOfLanes(dt_i8, Combine(dt_i8, Zero(di8), v))); } -HWY_API Vec64 SumOfLanes(Vec64 v) { +template +HWY_API Vec64 SumOfLanes(D di8, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; - const Full128 di32; - const Vec128 zero = Zero(di32); - const Full64 di8; - return Broadcast(AltivecVsum2sws( - di8, AltivecVsum4sbs(di32, v.raw, zero.raw).raw, zero.raw)); + return Broadcast(BitCast(di8, SumsOf8(v))); +#endif } -HWY_API Vec128 SumOfLanes(Vec128 v) { +template +HWY_API Vec128 SumOfLanes(D di8, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 15; - const Full128 di8; const Full128 di32; const Vec128 zero = Zero(di32); - return Broadcast(AltivecVsumsws( - di8, AltivecVsum4sbs(di32, v.raw, zero.raw).raw, zero.raw)); + return Broadcast(detail::AltivecVsumsws( + di8, detail::AltivecVsum4sbs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif } -template -HWY_API Vec128 MaxOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Max(v, Reverse2(d, v)); - vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; +#if HWY_S390X_HAVE_Z14 +template +HWY_API VFromD SumOfLanes(D d32, VFromD v) { + const RebindToUnsigned du32; + return Broadcast<1>( + BitCast(d32, detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du32, v)))); } -template -HWY_API Vec128 MinOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Min(v, Reverse2(d, v)); - vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; +template +HWY_API VFromD SumOfLanes(D /*d32*/, VFromD v) { + return Broadcast<3>(detail::SumOfU32OrU64LanesAsU128(v)); } +#endif -template -HWY_API Vec128 MaxOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Max(v, Reverse2(d, v)); - vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; -} +// generic_ops defines MinOfLanes and MaxOfLanes. 
-template -HWY_API Vec128 MinOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Min(v, Reverse2(d, v)); - vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; -} +// ------------------------------ ReduceSum for N=4 I8/U8 -template -HWY_API Vec128 MinOfLanes(Vec128 v) { - const Simd d; - const RepartitionToWide d32; -#if HWY_IS_LITTLE_ENDIAN - const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); - const auto odd = ShiftRight<16>(BitCast(d32, v)); -#else - const auto even = ShiftRight<16>(BitCast(d32, v)); - const auto odd = And(BitCast(d32, v), Set(d32, 0xFFFF)); -#endif - const auto min = MinOfLanes(Min(even, odd)); - // Also broadcast into odd lanes on little-endian and into even lanes - // on big-endian - return Vec128{vec_pack(min.raw, min.raw)}; -} -template -HWY_API Vec128 MinOfLanes(Vec128 v) { - const Simd d; - const RepartitionToWide d32; - // Sign-extend -#if HWY_IS_LITTLE_ENDIAN - const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); - const auto odd = ShiftRight<16>(BitCast(d32, v)); +// GetLane(SumsOf4(v)) is more efficient on PPC/Z14 than the default N=4 +// I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 #else - const auto even = ShiftRight<16>(BitCast(d32, v)); - const auto odd = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); +#define HWY_NATIVE_REDUCE_SUM_4_UI8 #endif - const auto min = MinOfLanes(Min(even, odd)); - // Also broadcast into odd lanes on little-endian and into even lanes - // on big-endian - return Vec128{vec_pack(min.raw, min.raw)}; + +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); } -template -HWY_API Vec128 MaxOfLanes(Vec128 v) { - const Simd d; - const RepartitionToWide d32; -#if HWY_IS_LITTLE_ENDIAN - const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); - const auto odd = ShiftRight<16>(BitCast(d32, v)); +// ------------------------------ BitShuffle + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE #else - const auto even = ShiftRight<16>(BitCast(d32, v)); - const auto odd = And(BitCast(d32, v), Set(d32, 0xFFFF)); +#define HWY_NATIVE_BITSHUFFLE #endif - const auto max = MaxOfLanes(Max(even, odd)); - // Also broadcast into odd lanes. 
- return Vec128{vec_pack(max.raw, max.raw)}; -} -template -HWY_API Vec128 MaxOfLanes(Vec128 v) { - const Simd d; - const RepartitionToWide d32; - // Sign-extend + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + + const Full128> d_full_u64; + const Full128> d_full_u8; + + using RawVU64 = __vector unsigned long long; + +#if HWY_PPC_HAVE_9 + #if HWY_IS_LITTLE_ENDIAN - const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); - const auto odd = ShiftRight<16>(BitCast(d32, v)); + (void)d_full_u64; + auto bit_idx = ResizeBitCast(d_full_u8, idx); #else - const auto even = ShiftRight<16>(BitCast(d32, v)); - const auto odd = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + auto bit_idx = + BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))); #endif - const auto max = MaxOfLanes(Max(even, odd)); - // Also broadcast into odd lanes on little-endian and into even lanes - // on big-endian - return Vec128{vec_pack(max.raw, max.raw)}; -} -} // namespace detail + bit_idx = Xor(bit_idx, Set(d_full_u8, uint8_t{0x3F})); -// Supported for u/i/f 32/64. Returns the same value in each lane. -template -HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { - return detail::SumOfLanes(v); -} -template -HWY_API TFromD ReduceSum(D /* tag */, VFromD v) { - return GetLane(detail::SumOfLanes(v)); -} -template -HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { - return detail::MinOfLanes(v); -} -template -HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { - return detail::MaxOfLanes(v); + return BitCast(d64, VFromD{reinterpret_cast( + vec_bperm(BitCast(du64, v).raw, bit_idx.raw))}); +#else // !HWY_PPC_HAVE_9 + +#if HWY_IS_LITTLE_ENDIAN + const auto bit_idx_xor_mask = BitCast( + d_full_u8, Dup128VecFromValues(d_full_u64, uint64_t{0x7F7F7F7F7F7F7F7Fu}, + uint64_t{0x3F3F3F3F3F3F3F3Fu})); + const auto bit_idx = Xor(ResizeBitCast(d_full_u8, idx), bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 8; +#else + const auto bit_idx_xor_mask = BitCast( + d_full_u8, Dup128VecFromValues(d_full_u64, uint64_t{0x3F3F3F3F3F3F3F3Fu}, + uint64_t{0x7F7F7F7F7F7F7F7Fu})); + const auto bit_idx = + Xor(BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))), + bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 6; +#endif + +#if HWY_S390X_HAVE_Z14 + const VFromD bit_shuf_result{reinterpret_cast( + vec_bperm_u128(BitCast(du8, v).raw, bit_idx.raw))}; +#elif defined(__SIZEOF_INT128__) + using RawVU128 = __vector unsigned __int128; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#else + using RawVU128 = __vector unsigned char; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#endif + + return ResizeBitCast( + d64, PromoteTo(d_full_u64, + ResizeBitCast( + Rebind(), + CombineShiftRightBytes( + d_full_u64, bit_shuf_result, bit_shuf_result)))); +#endif // HWY_PPC_HAVE_9 } // ------------------------------ Lt128 @@ -5672,7 +7212,20 @@ HWY_API V Max128Upper(D d, const V a, const V b) { template HWY_API V LeadingZeroCount(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); +#endif + + return BitCast(d, VFromD{vec_cntlz(BitCast(du, v).raw)}); +#else 
return V{vec_cntlz(v.raw)}; +#endif } template @@ -5682,14 +7235,27 @@ HWY_API V HighestSetBitIndex(V v) { return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); } -#if HWY_PPC_HAVE_9 +#if HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 template HWY_API V TrailingZeroCount(V v) { #if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 return V{vec_vctz(v.raw)}; #else - return V{vec_cnttz(v.raw)}; +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); #endif + + return BitCast(d, VFromD{vec_cnttz(BitCast(du, v).raw)}); +#else + return V{vec_cnttz(v.raw)}; +#endif // HWY_S390X_HAVE_Z14 +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 } #else template @@ -5709,6 +7275,8 @@ HWY_API V TrailingZeroCount(V v) { #undef HWY_PPC_HAVE_9 #undef HWY_PPC_HAVE_10 +#undef HWY_S390X_HAVE_Z14 +#undef HWY_S390X_HAVE_Z15 // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/r/src/vendor/highway/hwy/ops/rvv-inl.h b/r/src/vendor/highway/hwy/ops/rvv-inl.h index 0bde49e3..e65602c6 100644 --- a/r/src/vendor/highway/hwy/ops/rvv-inl.h +++ b/r/src/vendor/highway/hwy/ops/rvv-inl.h @@ -339,8 +339,15 @@ namespace detail { // for code folding // Full support for f16 in all ops #define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ HWY_RVV_FOREACH_F16_UNCONDITIONAL(X_MACRO, NAME, OP, LMULS) +// Only BF16 is emulated. +#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) #else #define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) #endif #define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) @@ -389,15 +396,11 @@ namespace detail { // for code folding // For all combinations of SEW: #define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) #define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) #define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ @@ -409,8 +412,7 @@ namespace detail { // for code folding HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) #define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ - HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) // Assemble types for use in x-macros @@ -438,22 +440,134 @@ HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) // ------------------------------ Lanes // WARNING: we want to query VLMAX/sizeof(T), but this may actually change VL! 
-#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ - MLEN, NAME, OP) \ - template \ - HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ - constexpr size_t kFull = HWY_LANES(HWY_RVV_T(BASE, SEW)); \ - constexpr size_t kCap = MaxLanes(d); \ - /* If no cap, avoid generating a constant by using VLMAX. */ \ - return N == kFull ? __riscv_vsetvlmax_e##SEW##LMUL() \ - : __riscv_vsetvl_e##SEW##LMUL(kCap); \ - } \ - template \ - HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ - /* If no cap, avoid the HWY_MIN. */ \ - return detail::IsFull(d) \ - ? __riscv_vsetvl_e##SEW##LMUL(cap) \ - : __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, MaxLanes(d))); \ + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD +// HWY_RVV_CAPPED_LANES_SPECIAL_CASES provides some additional optimizations +// to CappedLanes in non-debug builds +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + if (__builtin_constant_p(cap >= kMaxLanes) && (cap >= kMaxLanes)) { \ + /* If cap is known to be greater than or equal to MaxLanes(d), */ \ + /* HWY_MIN(cap, Lanes(d)) will be equal to Lanes(d) */ \ + return Lanes(d); \ + } \ + \ + if ((__builtin_constant_p((cap & (cap - 1)) == 0) && \ + ((cap & (cap - 1)) == 0)) || \ + (__builtin_constant_p(cap <= HWY_MAX(kMinLanesPerFullVec, 4)) && \ + (cap <= HWY_MAX(kMinLanesPerFullVec, 4)))) { \ + /* If cap is known to be a power of 2, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as kMaxLanes is a power of 2 and */ \ + /* as (cap > VLMAX && cap < 2 * VLMAX) can only be true if cap is not a */ \ + /* power of 2 since VLMAX is always a power of 2 */ \ + \ + /* If cap is known to be less than or equal to 4, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as HWY_MIN(cap, kMaxLanes) <= 4 is */ \ + /* true if cap <= 4 and as vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) */ \ + /* if HWY_MIN(cap, kMaxLanes) <= 4 is true */ \ + \ + /* If cap is known to be less than or equal to kMinLanesPerFullVec, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMinLanesPerFullVec is true if */ \ + /* cap <= kMinLanesPerFullVec is true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then either */ \ + /* cap <= 4 or cap <= kMinLanesPerFullVec must be true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is known to be true, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) */ \ + \ + /* If no cap, avoid the HWY_MIN. */ \ + return detail::IsFull(d) \ + ? __riscv_vsetvl_e##SEW##LMUL(cap) \ + : __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } +#else +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) +#endif + +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + constexpr size_t kFull = HWY_LANES(HWY_RVV_T(BASE, SEW)); \ + constexpr size_t kCap = MaxLanes(d); \ + /* If no cap, avoid generating a constant by using VLMAX. */ \ + return N == kFull ? 
__riscv_vsetvlmax_e##SEW##LMUL() \ + : __riscv_vsetvl_e##SEW##LMUL(kCap); \ + } \ + template \ + HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ + /* NOTE: Section 6.3 of the RVV specification, which can be found at */ \ + /* https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc, */ \ + /* allows vsetvl to return a result less than Lanes(d) but greater than */ \ + /* or equal to ((cap + 1) / 2) if */ \ + /* (Lanes(d) > 2 && cap > HWY_MAX(Lanes(d), 4) && cap < (2 * Lanes(d))) */ \ + /* is true */ \ + \ + /* VLMAX is the number of lanes in a vector of type */ \ + /* VFromD, which is returned by */ \ + /* Lanes(DFromV>()) */ \ + \ + /* VLMAX is guaranteed to be a power of 2 under Section 2 of the RVV */ \ + /* specification */ \ + \ + /* The VLMAX of a vector of type VFromD is at least 2 as */ \ + /* the HWY_RVV target requires support for the RVV Zvl128b extension, */ \ + /* which guarantees that vectors with LMUL=1 are at least 16 bytes */ \ + \ + /* If VLMAX == 2 is true, then vsetvl(cap) is equal to HWY_MIN(cap, 2) */ \ + /* as cap == 3 is the only value such that */ \ + /* (cap > VLMAX && cap < 2 * VLMAX) if VLMAX == 2 and as */ \ + /* ((3 + 1) / 2) is equal to 2 */ \ + \ + /* If cap <= 4 is true, then vsetvl(cap) must be equal to */ \ + /* HWY_MIN(cap, VLMAX) as cap <= VLMAX is true if VLMAX >= 4 is true */ \ + /* and as vsetvl(cap) is guaranteed to be equal to HWY_MIN(cap, VLMAX) */ \ + /* if VLMAX == 2 */ \ + \ + /* We want CappedLanes(d, cap) to return Lanes(d) if cap > Lanes(d) as */ \ + /* LoadN(d, p, cap) expects to load exactly HWY_MIN(cap, Lanes(d)) */ \ + /* lanes and StoreN(v, d, p, cap) expects to store exactly */ \ + /* HWY_MIN(cap, Lanes(d)) lanes, even in the case where vsetvl returns */ \ + /* a result that is less than HWY_MIN(cap, Lanes(d)) */ \ + \ + /* kMinLanesPerFullVec is the minimum value of VLMAX for a vector of */ \ + /* type VFromD */ \ + constexpr size_t kMinLanesPerFullVec = \ + detail::ScaleByPower(16 / (SEW / 8), SHIFT); \ + /* kMaxLanes is the maximum number of lanes returned by Lanes(d) */ \ + constexpr size_t kMaxLanes = MaxLanes(d); \ + \ + HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + \ + if (kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4)) { \ + /* If kMaxLanes <= kMinLanesPerFullVec is true, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return */ \ + /* HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMaxLanes <= VLMAX is true if */ \ + /* kMaxLanes <= kMinLanesPerFullVec is true */ \ + \ + /* If kMaxLanes <= 4 is true, then vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= 4 is true if kMaxLanes <= 4 is true */ \ + \ + /* If kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then */ \ + /* either kMaxLanes <= 4 or kMaxLanes <= kMinLanesPerFullVec must be */ \ + /* true */ \ + \ + return __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } else { \ + /* If kMaxLanes > HWY_MAX(kMinLanesPerFullVec, 4) is true, need to */ \ + /* obtain the actual number of lanes using Lanes(d) and clamp cap to */ \ + /* the result of Lanes(d) */ \ + const size_t actual = Lanes(d); \ + return HWY_MIN(actual, cap); \ + } \ } #define HWY_RVV_LANES_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ @@ -480,18 +594,18 @@ HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL) HWY_RVV_FOREACH(HWY_RVV_LANES_VIRT, Lanes, lenb, _VIRT) -// If not 
already defined via HWY_RVV_FOREACH, define the overloads because -// they do not require any new instruction. -#if !HWY_HAVE_FLOAT16 -HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL) -HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_LANES_VIRT, Lanes, lenb, _VIRT) -#endif #undef HWY_RVV_LANES #undef HWY_RVV_LANES_VIRT +#undef HWY_RVV_CAPPED_LANES_SPECIAL_CASES + +template +HWY_API size_t Lanes(D /* tag*/) { + return Lanes(RebindToUnsigned()); +} -template -HWY_API size_t Lanes(Simd /* tag*/) { - return Lanes(Simd()); +template +HWY_API size_t CappedLanes(D /* tag*/, size_t cap) { + return CappedLanes(RebindToUnsigned(), cap); } // ------------------------------ Common x-macros @@ -525,10 +639,20 @@ HWY_API size_t Lanes(Simd /* tag*/) { HWY_RVV_AVL(SEW, SHIFT)); \ } +// vector = f(vector, mask, vector, vector), e.g. MaskedAddOr +#define HWY_RVV_RETV_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) m, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(m, no, a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + // mask = f(mask) -#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ - HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ - return __riscv_vm##OP##_m_b##MLEN(m, ~0ull); \ +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return __riscv_vm##OP##_m_b##MLEN(m, HWY_RVV_AVL(SEW, SHIFT)); \ } // ================================================== INIT @@ -549,21 +673,17 @@ HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) // Treat bfloat16_t as int16_t (using the previously defined Set overloads); // required for Zero and VFromD. -template -decltype(Set(Simd(), 0)) Set(Simd d, - bfloat16_t arg) { - return Set(RebindToSigned(), arg.bits); +template +decltype(Set(RebindToSigned(), 0)) Set(D d, hwy::bfloat16_t arg) { + return Set(RebindToSigned(), BitCastScalar(arg)); } #if !HWY_HAVE_FLOAT16 // Otherwise already defined above. // WARNING: returns a different type than emulated bfloat16_t so that we can // implement PromoteTo overloads for both bfloat16_t and float16_t, and also -// provide a Neg(float16_t) overload that coexists with Neg(int16_t). -template -decltype(Set(Simd(), 0)) Set(Simd d, - float16_t arg) { - uint16_t bits; - CopySameSize(&arg, &bits); - return Set(RebindToUnsigned(), bits); +// provide a Neg(hwy::float16_t) overload that coexists with Neg(int16_t). 
+template +decltype(Set(RebindToUnsigned(), 0)) Set(D d, hwy::float16_t arg) { + return Set(RebindToUnsigned(), BitCastScalar(arg)); } #endif @@ -642,16 +762,7 @@ HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) #undef HWY_RVV_EXT_VIRT -#if !HWY_HAVE_FLOAT16 -template -VFromD Ext(D d, VFromD> v) { - const RebindToUnsigned du; - const Half duh; - return BitCast(d, Ext(du, BitCast(duh, v))); -} -#endif - -template +template VFromD Ext(D d, VFromD> v) { const RebindToUnsigned du; const Half duh; @@ -767,10 +878,10 @@ HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_IF, _, reinterpret, _ALL) HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) #else -template -HWY_INLINE VFromD> BitCastFromByte( - Simd /* d */, VFromD> v) { - return BitCastFromByte(Simd(), v); +template +HWY_INLINE VFromD> BitCastFromByte( + D /* d */, VFromD> v) { + return BitCastFromByte(RebindToUnsigned(), v); } #endif @@ -781,10 +892,10 @@ HWY_INLINE VFromD> BitCastFromByte( #undef HWY_RVV_CAST_VIRT_U #undef HWY_RVV_CAST_VIRT_IF -template -HWY_INLINE VFromD> BitCastFromByte( - Simd /* d */, VFromD> v) { - return BitCastFromByte(Simd(), v); +template +HWY_INLINE VFromD> BitCastFromByte( + D d, VFromD> v) { + return BitCastFromByte(RebindToSigned(), v); } } // namespace detail @@ -942,6 +1053,35 @@ HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, SubS, sub_vx, _ALL) HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) +// ------------------------------ Neg (ReverseSubS, Xor) + +template +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, v, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +#if !HWY_HAVE_FLOAT16 + +template )> // hwy::float16_t +HWY_API V Neg(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} + +#endif // !HWY_HAVE_FLOAT16 + // ------------------------------ SaturatedAdd #ifdef HWY_NATIVE_I32_SATURATED_ADDSUB @@ -978,6 +1118,18 @@ HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) // ------------------------------ AverageRound +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + // Define this to opt-out of the default behavior, which is AVOID on certain // compiler versions. You can define only this to use VXRM, or define both this // and HWY_RVV_AVOID_VXRM to always avoid VXRM. 
@@ -1017,8 +1169,8 @@ HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ } -HWY_RVV_FOREACH_U08(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL) -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_AVERAGE, AverageRound, aadd, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL) #undef HWY_RVV_RETV_AVERAGE @@ -1047,8 +1199,37 @@ HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) #undef HWY_RVV_SHIFT +// ------------------------------ RoundingShiftRight[Same] + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +// Intrinsics do not define .vi forms, so use .vx instead. +#define HWY_RVV_ROUNDING_SHR(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, kBits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, static_cast(bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssrl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR + // ------------------------------ SumsOf8 (ShiftRight, Add) -template +template )> HWY_API VFromD>> SumsOf8(const VU8 v) { const DFromV du8; const RepartitionToWide du16; @@ -1071,13 +1252,42 @@ HWY_API VFromD>> SumsOf8(const VU8 v) { return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); } +template )> +HWY_API VFromD>> SumsOf8(const VI8 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + // ------------------------------ RotateRight -template +template HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; - return Or(ShiftRight(v), + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } @@ -1111,6 +1321,33 @@ HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) #undef HWY_RVV_SHIFT_II #undef HWY_RVV_SHIFT_VV +// ------------------------------ RoundingShr +#define HWY_RVV_ROUNDING_SHR_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + 
NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, bits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR_VV, RoundingShr, ssrl, _ALL) + +#define HWY_RVV_ROUNDING_SHR_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + const HWY_RVV_D(uint, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT) du; \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, BitCast(du, bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR_II, RoundingShr, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR_VV +#undef HWY_RVV_ROUNDING_SHR_II + // ------------------------------ Min namespace detail { @@ -1158,15 +1395,8 @@ HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) // ------------------------------ MulHigh -// Only for internal use (Highway only promises MulHigh for 16-bit inputs). -// Used by MulEven; vwmul does not work for m8. -namespace detail { HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) -} // namespace detail - -HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) -HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) // ------------------------------ MulFixedPoint15 @@ -1184,8 +1414,57 @@ HWY_RVV_FOREACH_I16(HWY_RVV_MUL15, MulFixedPoint15, smul, _ALL) #undef HWY_RVV_MUL15 // ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Div, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Div, div, _ALL) HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Mod, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Mod, rem, _ALL) + +// ------------------------------ MaskedAddOr etc. 
+ +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMinOr, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMinOr, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMinOr, fmin, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, fmax, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedAddOr, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedAddOr, fadd, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedSubOr, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedSubOr, fsub, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedMulOr, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMulOr, fmul, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedDivOr, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedDivOr, div, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedDivOr, fdiv, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedModOr, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedModOr, rem, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, saddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, sadd, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssubu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssub, _ALL) + // ------------------------------ ApproximateReciprocal #ifdef HWY_NATIVE_F64_APPROX_RECIP #undef HWY_NATIVE_F64_APPROX_RECIP @@ -1247,26 +1526,6 @@ HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) // vboolXX_t is a power of two divisor for vector bits. SEW=8 / LMUL=1 = 1/8th // of all bits; SEW=8 / LMUL=4 = half of all bits. -// SFINAE for mapping Simd<> to MLEN (up to 64). -#define HWY_RVV_IF_MLEN_D(D, MLEN) \ - hwy::EnableIf* = nullptr - -// Specialized for RVV instead of the generic test_util-inl.h implementation -// because more efficient, and helps implement MFromD. - -#define HWY_RVV_MASK_FALSE(SEW, SHIFT, MLEN, NAME, OP) \ - template \ - HWY_API HWY_RVV_M(MLEN) NAME(D d) { \ - return __riscv_vm##OP##_m_b##MLEN(Lanes(d)); \ - } - -HWY_RVV_FOREACH_B(HWY_RVV_MASK_FALSE, MaskFalse, clr) -#undef HWY_RVV_MASK_FALSE -#undef HWY_RVV_IF_MLEN_D - -template -using MFromD = decltype(MaskFalse(D())); - // mask = f(vector, vector) #define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ SHIFT, MLEN, NAME, OP) \ @@ -1405,11 +1664,49 @@ HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) #undef HWY_RVV_IF_THEN_ZERO_ELSE // ------------------------------ MaskFromVec + +template +using MFromD = decltype(Eq(Zero(D()), Zero(D()))); + template HWY_API MFromD> MaskFromVec(const V v) { return detail::NeS(v, 0); } +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtS(BitCast(di, v), static_cast(0)); +} + +// ------------------------------ MaskFalse + +// For mask ops including vmclr, elements past VL are tail-agnostic and cannot +// be relied upon, so define a variant of the generic_ops-inl implementation of +// MaskFalse that ensures all bits are zero as required by mask_test. 
+#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API MFromD MaskFalse(D d) { + const DFromV> d_full; + return MaskFromVec(Zero(d_full)); +} + // ------------------------------ RebindMask template HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { @@ -1427,10 +1724,12 @@ HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_M(MLEN) m) { \ - const RebindToSigned di; \ + /* MaskFalse requires we set all lanes for capped d and virtual LMUL. */ \ + const DFromV> d_full; \ + const RebindToSigned di; \ using TI = TFromD; \ - return BitCast( \ - d, __riscv_v##OP##_i##SEW##LMUL(Zero(di), TI{-1}, m, Lanes(d))); \ + return BitCast(d_full, __riscv_v##OP##_i##SEW##LMUL(Zero(di), TI{-1}, m, \ + Lanes(d_full))); \ } HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, VecFromMask, merge_vxm, _ALL_VIRT) @@ -1448,14 +1747,8 @@ HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { return IfThenElse(MaskFromVec(mask), yes, no); } -// ------------------------------ ZeroIfNegative -template -HWY_API V ZeroIfNegative(const V v) { - return IfThenZeroElse(detail::LtS(v, 0), v); -} - // ------------------------------ BroadcastSignBit -template +template HWY_API V BroadcastSignBit(const V v) { return ShiftRight) * 8 - 1>(v); } @@ -1464,11 +1757,7 @@ HWY_API V BroadcastSignBit(const V v) { template HWY_API V IfNegativeThenElse(V v, V yes, V no) { static_assert(IsSigned>(), "Only works for signed/float"); - const DFromV d; - const RebindToSigned di; - - MFromD m = detail::LtS(BitCast(di, v), 0); - return IfThenElse(m, yes, no); + return IfThenElse(IsNegative(v), yes, no); } // ------------------------------ FindFirstTrue @@ -1518,6 +1807,38 @@ HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) #undef HWY_RVV_COUNT_TRUE +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template )), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template ) - 1), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + // ================================================== MEMORY // ------------------------------ Load @@ -1528,47 +1849,18 @@ HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ - reinterpret_cast(p), Lanes(d)); \ + detail::NativeLanePointer(p), Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) #undef HWY_RVV_LOAD -// There is no native BF16, treat as uint16_t. -template -HWY_API VFromD> Load(Simd d, - const bfloat16_t* HWY_RESTRICT p) { - return Load(RebindToSigned(), - reinterpret_cast(p)); -} - -template -HWY_API void Store(VFromD> v, - Simd d, bfloat16_t* HWY_RESTRICT p) { - Store(v, RebindToSigned(), - reinterpret_cast(p)); -} - -#if !HWY_HAVE_FLOAT16 // Otherwise already defined above. 
- -// NOTE: different type for float16_t than bfloat16_t, see Set(). -template -HWY_API VFromD> Load(Simd d, - const float16_t* HWY_RESTRICT p) { - return Load(RebindToUnsigned(), - reinterpret_cast(p)); -} - -template -HWY_API void Store(VFromD> v, - Simd d, float16_t* HWY_RESTRICT p) { - Store(v, RebindToUnsigned(), - reinterpret_cast(p)); +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(p))); } -#endif // !HWY_HAVE_FLOAT16 - // ------------------------------ LoadU template HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { @@ -1584,23 +1876,37 @@ HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ - m, Zero(d), reinterpret_cast(p), Lanes(d)); \ + m, Zero(d), detail::NativeLanePointer(p), Lanes(d)); \ } \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME##Or(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ - m, v, reinterpret_cast(p), Lanes(d)); \ + m, v, detail::NativeLanePointer(p), Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) #undef HWY_RVV_MASKED_LOAD +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD no, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, MaskedLoadOr(BitCast(du, no), RebindMask(du, m), du, + detail::U16LanePointer(p))); +} + // ------------------------------ LoadN // Native with avl is faster than the generic_ops using FirstN. 
@@ -1616,29 +1922,41 @@ HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ - using T = detail::NativeLaneType; \ /* Use a tail-undisturbed load in LoadN as the tail-undisturbed load */ \ /* operation below will leave any lanes past the first */ \ /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes unchanged */ \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ - Zero(d), reinterpret_cast(p), CappedLanes(d, num_lanes)); \ + Zero(d), detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ } \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME##Or( \ HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ - using T = detail::NativeLaneType; \ /* Use a tail-undisturbed load in LoadNOr as the tail-undisturbed load */ \ /* operation below will set any lanes past the first */ \ /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes to the */ \ /* corresponding lanes in no */ \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ - no, reinterpret_cast(p), CappedLanes(d, num_lanes)); \ + no, detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ } HWY_RVV_FOREACH(HWY_RVV_LOADN, LoadN, le, _ALL_VIRT) #undef HWY_RVV_LOADN +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} +template +HWY_API VFromD LoadNOr(VFromD v, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, v), du, detail::U16LanePointer(p), num_lanes)); +} + // ------------------------------ Store #define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ @@ -1647,13 +1965,18 @@ HWY_RVV_FOREACH(HWY_RVV_LOADN, LoadN, le, _ALL_VIRT) HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL(reinterpret_cast(p), \ - v, Lanes(d)); \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), v, Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) #undef HWY_RVV_STORE +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, detail::U16LanePointer(p)); +} + // ------------------------------ BlendedStore #define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ @@ -1662,13 +1985,20 @@ HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_m( \ - m, reinterpret_cast(p), v, Lanes(d)); \ + m, detail::NativeLanePointer(p), v, Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) #undef HWY_RVV_BLENDED_STORE +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + // ------------------------------ StoreN namespace detail { @@ -1679,13 +2009,18 @@ namespace detail { HWY_API 
void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ - using T = detail::NativeLaneType; \ - return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL(reinterpret_cast(p), \ - v, count); \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), v, count); \ } HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) #undef HWY_RVV_STOREN +template +HWY_API void StoreN(size_t count, VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreN(count, BitCast(du, v), du, detail::U16LanePointer(p)); +} + } // namespace detail #ifdef HWY_NATIVE_STORE_N @@ -1694,13 +2029,12 @@ HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) #define HWY_NATIVE_STORE_N #endif -template , - hwy::EnableIf>>()>* = nullptr> -HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, size_t max_lanes_to_store) { - // NOTE: Need to call Lanes(d) and clamp max_lanes_to_store to Lanes(d), even - // if MaxLanes(d) >= MaxLanes(DFromV>()) is true, as it is possible - // for detail::StoreN(max_lanes_to_store, v, d, p) to store fewer than + // NOTE: Need to clamp max_lanes_to_store to Lanes(d), even if + // MaxLanes(d) >= MaxLanes(DFromV>()) is true, as it is possible for + // detail::StoreN(max_lanes_to_store, v, d, p) to store fewer than // Lanes(DFromV>()) lanes to p if // max_lanes_to_store > Lanes(DFromV>()) and // max_lanes_to_store < 2 * Lanes(DFromV>()) are both true. @@ -1709,21 +2043,7 @@ HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, // if Lanes(d) < Lanes(DFromV>()) is true, which is possible if // MaxLanes(d) < MaxLanes(DFromV>()) or // d.Pow2() < DFromV>().Pow2() is true. - const size_t N = Lanes(d); - detail::StoreN(HWY_MIN(max_lanes_to_store, N), v, d, p); -} - -// StoreN for BF16/F16 vectors -template , - hwy::EnableIf>>()>* = nullptr, - HWY_IF_SPECIAL_FLOAT(T)> -HWY_API void StoreN(VFromD v, D /*d*/, T* HWY_RESTRICT p, - size_t max_lanes_to_store) { - using TStore = TFromV>; - const Rebind d_store; - const size_t N = Lanes(d_store); - detail::StoreN(HWY_MIN(max_lanes_to_store, N), v, d_store, - reinterpret_cast(p)); + detail::StoreN(CappedLanes(d, max_lanes_to_store), v, d, p); } // ------------------------------ StoreU @@ -1747,17 +2067,16 @@ HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { #define HWY_NATIVE_SCATTER #endif -#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ - SHIFT, MLEN, NAME, OP) \ - template \ - HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ - HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ - HWY_RVV_V(int, SEW, LMUL) offset) { \ - const RebindToUnsigned du; \ - using T = detail::NativeLaneType; \ - return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ - reinterpret_cast(base), BitCast(du, offset), v, Lanes(d)); \ +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + const RebindToUnsigned du; \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(base), BitCast(du, offset), v, Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) #undef HWY_RVV_SCATTER @@ -1772,19 +2091,18 @@ HWY_API void ScatterIndex(VFromD v, D d, TFromD* 
HWY_RESTRICT base, // ------------------------------ MaskedScatterIndex -#define HWY_RVV_MASKED_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ - LMULH, SHIFT, MLEN, NAME, OP) \ - template \ - HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ - HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ - HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ - HWY_RVV_V(int, SEW, LMUL) indices) { \ - const RebindToUnsigned du; \ - using T = detail::NativeLaneType; \ - constexpr size_t kBits = CeilLog2(sizeof(TFromD)); \ - return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_m( \ - m, reinterpret_cast(base), ShiftLeft(BitCast(du, indices)), \ - v, Lanes(d)); \ +#define HWY_RVV_MASKED_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) indices) { \ + const RebindToUnsigned du; \ + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_m( \ + m, detail::NativeLanePointer(base), \ + ShiftLeft(BitCast(du, indices)), v, Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_MASKED_SCATTER, MaskedScatterIndex, sux, _ALL_VIRT) #undef HWY_RVV_MASKED_SCATTER @@ -1805,9 +2123,8 @@ HWY_RVV_FOREACH(HWY_RVV_MASKED_SCATTER, MaskedScatterIndex, sux, _ALL_VIRT) const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ HWY_RVV_V(int, SEW, LMUL) offset) { \ const RebindToUnsigned du; \ - using T = detail::NativeLaneType; \ return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ - reinterpret_cast(base), BitCast(du, offset), Lanes(d)); \ + detail::NativeLanePointer(base), BitCast(du, offset), Lanes(d)); \ } HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) #undef HWY_RVV_GATHER @@ -1821,25 +2138,34 @@ HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, return GatherOffset(d, base, ShiftLeft(index)); } -// ------------------------------ MaskedGatherIndex +// ------------------------------ MaskedGatherIndexOr #define HWY_RVV_MASKED_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ SHIFT, MLEN, NAME, OP) \ template \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ HWY_RVV_V(int, SEW, LMUL) indices) { \ const RebindToUnsigned du; \ - using T = detail::NativeLaneType; \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ constexpr size_t kBits = CeilLog2(SEW / 8); \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_mu( \ - m, Zero(d), reinterpret_cast(base), \ + m, no, detail::NativeLanePointer(base), \ ShiftLeft(BitCast(du, indices)), Lanes(d)); \ } -HWY_RVV_FOREACH(HWY_RVV_MASKED_GATHER, MaskedGatherIndex, lux, _ALL_VIRT) +HWY_RVV_FOREACH(HWY_RVV_MASKED_GATHER, MaskedGatherIndexOr, lux, _ALL_VIRT) #undef HWY_RVV_MASKED_GATHER +template +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, const TFromD* base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); +} + // ================================================== CONVERT // ------------------------------ PromoteTo @@ -1952,52 +2278,38 @@ HWY_API auto PromoteTo(Simd d, } // Unsigned to signed: cast for unsigned promote. 
-template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { return BitCast(d, PromoteTo(RebindToUnsigned(), v)); } -template -HWY_API auto PromoteTo(Simd d, - VFromD> v) - -> VFromD { +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { const RebindToSigned di32; const Rebind du16; return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); @@ -2097,28 +2409,24 @@ HWY_API vuint8m2_t DemoteTo(Simd d, const vuint32m8_t v) { HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { @@ -2501,16 +2809,14 @@ HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { return DemoteTo(d, DemoteTo(Simd(), v)); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - return DemoteTo(d, DemoteTo(Simd(), v)); +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); } #undef HWY_RVV_DEMOTE @@ -2527,9 +2833,15 @@ HWY_API VFromD> DemoteTo( } #if HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C -HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_rod_f_f_w_f, _DEMOTE_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) #endif -HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_rod_f_f_w_f, _DEMOTE_VIRT) +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) + +namespace detail { +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteToF32WithRoundToOdd, + fncvt_rod_f_f_w_f, _DEMOTE_VIRT) +} // namespace detail + #undef HWY_RVV_DEMOTE_F // TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. 
@@ -2617,28 +2929,73 @@ HWY_API vfloat32m4_t DemoteTo(Simd d, const vuint64m8_t v) { return __riscv_vfncvt_f_xu_w_f32m4(v, Lanes(d)); } +// Narrows f32 bits to bf16 using round to even. // SEW is for the source so we can use _DEMOTE_VIRT. -#define HWY_RVV_DEMOTE_TO_SHR_16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ - LMULH, SHIFT, MLEN, NAME, OP) \ +#ifdef HWY_RVV_AVOID_VXRM +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + const auto round = \ + detail::AddS(detail::AndS(ShiftRight<16>(v), 1u), 0x7FFFu); \ + v = Add(v, round); \ + /* The default rounding mode appears to be RNU=0, which adds the LSB. */ \ + /* Prevent further rounding by clearing the bits we want to truncate. */ \ + v = detail::AndS(v, 0xFFFF0000u); \ + return __riscv_v##OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +#else +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ template \ HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ return __riscv_v##OP##CHAR##SEWH##LMULH( \ - v, 16, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); \ + v, 16, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNE, Lanes(d))); \ } +#endif // HWY_RVV_AVOID_VXRM namespace detail { -HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE_TO_SHR_16, DemoteToShr16, nclipu_wx_, - _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE_16_NEAREST_EVEN, DemoteTo16NearestEven, + nclipu_wx_, _DEMOTE_VIRT) } -#undef HWY_RVV_DEMOTE_TO_SHR_16 +#undef HWY_RVV_DEMOTE_16_NEAREST_EVEN -template -HWY_API VFromD> DemoteTo( - Simd d, VFromD> v) { - const RebindToUnsigned du16; - const Rebind du32; - return BitCast(d, detail::DemoteToShr16(du16, BitCast(du32, v))); -} +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(DBF16 d, VFromD> v) { + const DFromV df; + const RebindToUnsigned du32; + const RebindToUnsigned du16; + // Consider an f32 mantissa with the upper 7 bits set, followed by a 1-bit + // and at least one other bit set. This will round to 0 and increment the + // exponent. If the exponent was already 0xFF (NaN), then the result is -inf; + // there no wraparound because nclipu saturates. Note that in this case, the + // input cannot have been inf because its mantissa bits are zero. To avoid + // converting NaN to inf, we canonicalize the NaN to prevent the rounding. + const decltype(v) canonicalized = + IfThenElse(Eq(v, v), v, BitCast(df, Set(du32, 0x7F800000))); + return BitCast( + d, detail::DemoteTo16NearestEven(du16, BitCast(du32, canonicalized))); +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} // ------------------------------ ConvertTo F @@ -2664,8 +3021,8 @@ HWY_API VFromD> DemoteTo( HWY_API HWY_RVV_V(uint, SEW, LMUL) ConvertTo( \ HWY_RVV_D(uint, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ return __riscv_vfcvt_rtz_xu_f_v_u##SEW##LMUL(v, Lanes(d)); \ - } \ -// API only requires f32 but we provide f64 for internal use. 
+ } + HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) #undef HWY_RVV_CONVERT @@ -2678,6 +3035,32 @@ HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) #undef HWY_RVV_NEAREST +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteToNearestInt(Simd d, + const vfloat64m2_t v) { + return __riscv_vfncvt_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteToNearestInt(Simd d, + const vfloat64m4_t v) { + return __riscv_vfncvt_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteToNearestInt(Simd d, + const vfloat64m8_t v) { + return __riscv_vfncvt_x_f_w_i32m4(v, Lanes(d)); +} + // ================================================== COMBINE namespace detail { @@ -2704,7 +3087,7 @@ HWY_INLINE size_t LanesPerBlock(Simd d) { template HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { - using T = MakeUnsigned>; + using T = MakeUnsigned>; return AndS(iota0, static_cast(~(LanesPerBlock(d) - 1))); } @@ -2918,9 +3301,10 @@ HWY_RVV_FOREACH_B(HWY_RVV_SET_AT_OR_AFTER_FIRST, _, _) // ------------------------------ InsertLane -template -HWY_API V InsertLane(const V v, size_t i, TFromV t) { - const DFromV d; +// T template arg because TFromV might not match the hwy::float16_t argument. +template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; const RebindToUnsigned du; // Iota0 is unsigned only using TU = TFromD; const auto is_i = detail::EqS(detail::Iota0(du), static_cast(i)); @@ -2928,9 +3312,9 @@ HWY_API V InsertLane(const V v, size_t i, TFromV t) { } // For 8-bit lanes, Iota0 might overflow. -template -HWY_API V InsertLane(const V v, size_t i, TFromV t) { - const DFromV d; +template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; const auto zero = Zero(d); const auto one = Set(d, 1); const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); @@ -2991,6 +3375,18 @@ HWY_API V DupOdd(const V v) { return OddEven(v, down); } +// ------------------------------ InterleaveEven (OddEven) +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return OddEven(detail::Slide1Up(b), a); +} + +// ------------------------------ InterleaveOdd (OddEven) +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return OddEven(b, detail::Slide1Down(a)); +} + // ------------------------------ OddEvenBlocks template HWY_API V OddEvenBlocks(const V a, const V b) { @@ -3034,9 +3430,6 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { return IndicesFromVec(d, LoadU(Rebind(), idx)); } -// TODO(janwas): avoid using this for 8-bit; wrap in detail namespace. -// For large 8-bit vectors, index overflow will lead to incorrect results. -// Reverse already uses TableLookupLanes16 to prevent this. #define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ @@ -3045,12 +3438,14 @@ HWY_API VFromD> SetTableIndices(D d, const TI* idx) { HWY_RVV_AVL(SEW, SHIFT)); \ } +// TableLookupLanes is supported for all types, but beware that indices are +// likely to wrap around for 8-bit lanes. When using TableLookupLanes inside +// this file, ensure that it is safe or use TableLookupLanes16 instead. 
HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) #undef HWY_RVV_TABLE namespace detail { -// Used by I8/U8 Reverse #define HWY_RVV_TABLE16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ SHIFT, MLEN, NAME, OP) \ HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ @@ -3122,6 +3517,67 @@ HWY_API VFromD Reverse(D /* tag */, VFromD v) { return TableLookupLanes(v, idx); } +// ------------------------------ ResizeBitCast + +// Extends or truncates a vector to match the given d. +namespace detail { + +template +HWY_INLINE VFromD ChangeLMUL(D /* d */, VFromD v) { + return v; +} + +// Sanity check: when calling ChangeLMUL, the caller (ResizeBitCast) already +// BitCast to the same lane type. Note that V may use the native lane type for +// f16, so convert D to that before checking. +#define HWY_RVV_IF_SAME_T_DV(D, V) \ + hwy::EnableIf>, TFromV>()>* = nullptr + +// LMUL of VFromD < LMUL of V: need to truncate v +template >, DFromV().Pow2() - 1)> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Half dh_from; + static_assert( + DFromV>().Pow2() < DFromV().Pow2(), + "The LMUL of VFromD must be less than the LMUL of V"); + static_assert( + DFromV>().Pow2() <= DFromV>().Pow2(), + "The LMUL of VFromD must be less than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Trunc(v)); +} + +// LMUL of VFromD > LMUL of V: need to extend v +template >, DFromV().Pow2())> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Twice dt_from; + static_assert(DFromV>().Pow2() > DFromV().Pow2(), + "The LMUL of VFromD must be greater than " + "the LMUL of V"); + static_assert( + DFromV>().Pow2() >= DFromV>().Pow2(), + "The LMUL of VFromD must be greater than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Ext(dt_from, v)); +} + +#undef HWY_RVV_IF_SAME_T_DV + +} // namespace detail + +template +HWY_API VFromD ResizeBitCast(DTo /*dto*/, VFrom v) { + const DFromV d_from; + const Repartition du8_from; + const DFromV> d_to; + const Repartition du8_to; + return BitCast(d_to, detail::ChangeLMUL(du8_to, BitCast(du8_from, v))); +} + // ------------------------------ Reverse2 (RotateRight, OddEven) // Per-target flags to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. @@ -3307,7 +3763,7 @@ template HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, TFromD* HWY_RESTRICT unaligned) { const size_t count = CountTrue(d, mask); - detail::StoreN(count, Compress(v, mask), d, unaligned); + StoreN(Compress(v, mask), d, unaligned, count); return count; } @@ -3409,6 +3865,9 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { return Combine(d, LowerHalf(dh, hi_even), LowerHalf(dh, lo_even)); } +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + // ================================================== BLOCKWISE // ------------------------------ CombineShiftRightBytes @@ -3483,50 +3942,6 @@ HWY_API V Shuffle0123(const V v) { // ------------------------------ TableLookupBytes -// Extends or truncates a vector to match the given d. 
-namespace detail { - -template -HWY_INLINE VFromD ChangeLMUL(D /* d */, VFromD v) { - return v; -} - -// LMUL of VFromD < LMUL of V: need to truncate v -template , TFromV>()>* = nullptr, - HWY_IF_POW2_LE_D(DFromV>, DFromV().Pow2() - 1)> -HWY_INLINE VFromD ChangeLMUL(D d, V v) { - const DFromV d_from; - const Half dh_from; - static_assert( - DFromV>().Pow2() < DFromV().Pow2(), - "The LMUL of VFromD must be less than the LMUL of V"); - static_assert( - DFromV>().Pow2() <= DFromV>().Pow2(), - "The LMUL of VFromD must be less than or equal to the LMUL of " - "VFromD"); - return ChangeLMUL(d, Trunc(v)); -} - -// LMUL of VFromD > LMUL of V: need to extend v -template , TFromV>()>* = nullptr, - HWY_IF_POW2_GT_D(DFromV>, DFromV().Pow2())> -HWY_INLINE VFromD ChangeLMUL(D d, V v) { - const DFromV d_from; - const Twice dt_from; - static_assert(DFromV>().Pow2() > DFromV().Pow2(), - "The LMUL of VFromD must be greater than " - "the LMUL of V"); - static_assert( - DFromV>().Pow2() >= DFromV>().Pow2(), - "The LMUL of VFromD must be greater than or equal to the LMUL of " - "VFromD"); - return ChangeLMUL(d, Ext(dt_from, v)); -} - -} // namespace detail - template HWY_API VI TableLookupBytes(const VT vt, const VI vi) { const DFromV dt; // T=table, I=index. @@ -3563,7 +3978,8 @@ HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { // ------------------------------ TwoTablesLookupLanes -// TODO(janwas): special-case 8-bit lanes to safely handle VL >= 256 +// WARNING: 8-bit lanes may lead to unexpected results because idx is the same +// size and may overflow. template HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, VFromD> idx) { @@ -3597,11 +4013,50 @@ HWY_API V TwoTablesLookupLanes(V a, V b, } // ------------------------------ Broadcast -template + +// 8-bit requires 16-bit tables. +template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_LE_D(D, 2)> HWY_API V Broadcast(const V v) { - const DFromV d; - const RebindToUnsigned du; + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return detail::TableLookupLanes16(v, idx); +} + +// 8-bit and max LMUL: split into halves. 
+template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_GT_D(D, 2)> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Half dh; + using VH = VFromD; + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + const VH lo = detail::TableLookupLanes16(LowerHalf(dh, v), idx); + const VH hi = detail::TableLookupLanes16(UpperHalf(dh, v), idx); + return Combine(d, hi, lo); +} + +template , + HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 2) | (1 << 4) | (1 << 8))> +HWY_API V Broadcast(const V v) { + const D d; HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const RebindToUnsigned du; auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(du)); if (kLane != 0) { idx = detail::AddS(idx, kLane); @@ -3778,20 +4233,194 @@ HWY_API V ShiftRightBytes(const D d, const V v) { return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); } -// ------------------------------ InterleaveLower +// ------------------------------ InterleaveWholeLower +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +namespace detail { +// Returns double-length vector with interleaved lanes. +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + using TW = MakeWide>; + const Rebind> dw; + const Half duh; // cast inputs to unsigned so we zero-extend -template + const VFromD aw = PromoteTo(dw, BitCast(duh, a)); + const VFromD bw = PromoteTo(dw, BitCast(duh, b)); + return BitCast(d, Or(aw, BitCast(dw, detail::Slide1Up(BitCast(du, bw))))); +} +// 64-bit: cannot PromoteTo, but can Ext. +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(detail::Ext(d, b), idx), + TableLookupLanes(detail::Ext(d, a), idx)); +} +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const Half dh; + const Half dq; + const VFromD i0 = + InterleaveWhole(dh, LowerHalf(dq, a), LowerHalf(dq, b)); + const VFromD i1 = + InterleaveWhole(dh, UpperHalf(dq, a), UpperHalf(dq, b)); + return Combine(d, i1, i0); +} + +} // namespace detail + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const detail::AdjustSimdTagToMinVecPow2> dw; + const RepartitionToNarrow du_src; + + const VFromD aw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, a))); + const VFromD bw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, b))); + return Or(aw, detail::Slide1Up(bw)); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and and as the results of InterleaveWholeUpper are + // implementation-defined if Lanes(d) is less than 2. 
+ const size_t half_N = Lanes(d) / 2; + return InterleaveWholeLower(d, detail::SlideDown(a, half_N), + detail::SlideDown(b, half_N)); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and as the results of InterleaveWholeUpper are implementation-defined + // if Lanes(d) is less than 2. + const size_t half_N = Lanes(d) / 2; + const RebindToUnsigned du; + const auto idx = detail::AddS(ShiftRight<1>(detail::Iota0(du)), + static_cast(half_N)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveLower (InterleaveWholeLower) + +namespace detail { + +// Definitely at least 128 bit: match x86 semantics (independent blocks). Using +// InterleaveWhole and 64-bit Compress avoids 8-bit overflow. +template +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only even 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_even = + detail::EqS(detail::AndS(idx_block, 1), 0); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_even))); +} +template +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveLowerBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveLowerBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// As above, for the upper half of blocks. +template +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only odd 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_odd = + detail::EqS(detail::AndS(idx_block, 1), 1); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_odd))); +} +template +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveUpperBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveUpperBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// RVV vectors are at least 128 bit when there is no fractional LMUL nor cap. +// Used by functions with per-block behavior such as InterleaveLower. +template +constexpr bool IsGE128(Simd /* d */) { + return N * sizeof(T) >= 16 && kPow2 >= 0; +} + +// Definitely less than 128-bit only if there is a small cap; fractional LMUL +// might not be enough if vectors are large. 
+template +constexpr bool IsLT128(Simd /* d */) { + return N * sizeof(T) < 16; +} + +} // namespace detail + +#define HWY_RVV_IF_GE128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_LT128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_CAN128_D(D) \ + hwy::EnableIf* = nullptr + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + return detail::InterleaveLowerBlocks(d, a, b); +} + +// Single block: interleave without extra Compress. +template HWY_API V InterleaveLower(D d, const V a, const V b) { static_assert(IsSame, TFromV>(), "D/V mismatch"); - const RebindToUnsigned du; - using TU = TFromD; - const auto i = detail::Iota0(du); - const auto idx_mod = ShiftRight<1>( - detail::AndS(i, static_cast(detail::LanesPerBlock(du) - 1))); - const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); - const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); - return IfThenElse(is_even, TableLookupLanes(a, idx), - TableLookupLanes(b, idx)); + return InterleaveWholeLower(d, a, b); +} + +// Could be either; branch at runtime. +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeLower(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. + const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveLowerBlocks( + d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b))); } template @@ -3799,21 +4428,30 @@ HWY_API V InterleaveLower(const V a, const V b) { return InterleaveLower(DFromV(), a, b); } -// ------------------------------ InterleaveUpper +// ------------------------------ InterleaveUpper (Compress) -template -HWY_API V InterleaveUpper(const D d, const V a, const V b) { +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + return detail::InterleaveUpperBlocks(d, a, b); +} + +// Single block: interleave without extra Compress. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { static_assert(IsSame, TFromV>(), "D/V mismatch"); - const RebindToUnsigned du; - using TU = TFromD; - const size_t lpb = detail::LanesPerBlock(du); - const auto i = detail::Iota0(du); - const auto idx_mod = ShiftRight<1>(detail::AndS(i, static_cast(lpb - 1))); - const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); - const auto idx = detail::AddS(idx_lower, static_cast(lpb / 2)); - const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); - return IfThenElse(is_even, TableLookupLanes(a, idx), - TableLookupLanes(b, idx)); + return InterleaveWholeUpper(d, a, b); +} + +// Could be either; branch at runtime. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeUpper(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. + const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveUpperBlocks( + d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b))); } // ------------------------------ ZipLower @@ -3840,67 +4478,98 @@ HWY_API VFromD ZipUpper(DW dw, V a, V b) { // ================================================== REDUCE -// vector = f(vector, zero_m1) +// We have ReduceSum, generic_ops-inl.h defines SumOfLanes via Set. 
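On RVV, `ReduceSum`/`ReduceMin`/`ReduceMax` each lower to a single reduction instruction that produces a scalar, and `SumOfLanes` is then just a broadcast of that scalar. A rough usage sketch with an assumed function name and a full-vector loop:

```cpp
#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Returns the sum of n floats; n is assumed to be a multiple of the vector
// length, so the scalar tail is omitted.
float SumArray(const float* HWY_RESTRICT p, size_t n) {
  const hn::ScalableTag<float> d;
  auto acc = hn::Zero(d);
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    acc = hn::Add(acc, hn::LoadU(d, p + i));
  }
  return hn::ReduceSum(d, acc);  // single horizontal reduction at the end
}
```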
+#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// scalar = f(vector, zero_m1) #define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ MLEN, NAME, OP) \ - template \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ - NAME(D d, HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ - return Set(d, \ - GetLane(__riscv_v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ - v, v0, Lanes(d)))); \ + template \ + HWY_API HWY_RVV_T(BASE, SEW) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_V(BASE, SEW, m1) v0) { \ + return GetLane(__riscv_v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v, v0, Lanes(d))); \ } -// ------------------------------ SumOfLanes +// detail::RedSum, detail::RedMin, and detail::RedMax is more efficient +// for N=4 I8/U8 reductions on RVV than the default implementations of the +// the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr + +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +// ------------------------------ ReduceSum namespace detail { -HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL_VIRT) } // namespace detail -template -HWY_API VFromD SumOfLanes(D d, const VFromD v) { +template +HWY_API TFromD ReduceSum(D d, const VFromD v) { const auto v0 = Zero(ScalableTag>()); // always m1 return detail::RedSum(d, v, v0); } -template -HWY_API TFromD ReduceSum(D d, const VFromD v) { - return GetLane(SumOfLanes(d, v)); -} - -// ------------------------------ MinOfLanes +// ------------------------------ ReduceMin namespace detail { -HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL) -HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL_VIRT) } // namespace detail -template -HWY_API VFromD MinOfLanes(D d, const VFromD v) { - using T = TFromD; +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, const VFromD v) { const ScalableTag d1; // always m1 - const auto neutral = Set(d1, HighestValue()); - return detail::RedMin(d, v, neutral); + return detail::RedMin(d, v, Set(d1, HighestValue())); } -// ------------------------------ MaxOfLanes +// ------------------------------ ReduceMax namespace detail { -HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL) -HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL) -HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL_VIRT) } // namespace detail -template -HWY_API VFromD MaxOfLanes(D d, const VFromD v) { - using T = TFromD; +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, const VFromD v) { const ScalableTag d1; // always m1 - const auto neutral = Set(d1, LowestValue()); - 
return detail::RedMax(d, v, neutral); + return detail::RedMax(d, v, Set(d1, LowestValue())); } #undef HWY_RVV_REDUCE +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + // ================================================== Ops with dependencies // ------------------------------ LoadInterleaved2 @@ -4116,7 +4785,7 @@ HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) #else // !HWY_HAVE_TUPLE -template > +template , HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1) { const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] @@ -4139,7 +4808,7 @@ HWY_RVV_FOREACH(HWY_RVV_LOAD_STRIDED, LoadStrided, lse, _ALL_VIRT) #undef HWY_RVV_LOAD_STRIDED } // namespace detail -template > +template , HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2) { // Offsets are bytes, and this is not documented. @@ -4148,7 +4817,7 @@ HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, v2 = detail::LoadStrided(d, unaligned + 2, 3 * sizeof(T)); } -template > +template , HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2, VFromD& v3) { @@ -4161,7 +4830,7 @@ HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, // Not 64-bit / max LMUL: interleave via promote, slide, OddEven. template , HWY_IF_NOT_T_SIZE_D(D, 8), - HWY_IF_POW2_LE_D(D, 2)> + HWY_IF_POW2_LE_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; @@ -4176,7 +4845,7 @@ HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, // Can promote, max LMUL: two half-length template , HWY_IF_NOT_T_SIZE_D(D, 8), - HWY_IF_POW2_GT_D(D, 2)> + HWY_IF_POW2_GT_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, T* HWY_RESTRICT unaligned) { const Half dh; @@ -4200,7 +4869,8 @@ HWY_RVV_FOREACH(HWY_RVV_STORE_STRIDED, StoreStrided, sse, _ALL_VIRT) } // namespace detail // 64-bit: strided -template , HWY_IF_T_SIZE_D(D, 8)> +template , HWY_IF_T_SIZE_D(D, 8), + HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, T* HWY_RESTRICT unaligned) { // Offsets are bytes, and this is not documented. @@ -4208,7 +4878,7 @@ HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, detail::StoreStrided(v1, d, unaligned + 1, 2 * sizeof(T)); } -template > +template , HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, T* HWY_RESTRICT unaligned) { // Offsets are bytes, and this is not documented. @@ -4217,7 +4887,7 @@ HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, detail::StoreStrided(v2, d, unaligned + 2, 3 * sizeof(T)); } -template > +template , HWY_RVV_IF_NOT_EMULATED_D(D)> HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, VFromD v3, D d, T* HWY_RESTRICT unaligned) { // Offsets are bytes, and this is not documented. 
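Whether the target has tuple support (segment loads/stores) or uses the strided fallback above, callers see the same `LoadInterleaved3`/`StoreInterleaved3` interface. A rough usage sketch; the function name is illustrative and the pixel count is assumed to be a multiple of the vector length:

```cpp
#include <stddef.h>
#include <stdint.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Swaps the R and B channels of packed RGB bytes in place.
void SwapRB(uint8_t* HWY_RESTRICT rgb, size_t num_pixels) {
  const hn::ScalableTag<uint8_t> d;
  const size_t N = hn::Lanes(d);
  for (size_t i = 0; i < num_pixels; i += N) {
    auto r = hn::Undefined(d), g = hn::Undefined(d), b = hn::Undefined(d);
    hn::LoadInterleaved3(d, rgb + 3 * i, r, g, b);   // de-interleave N pixels
    hn::StoreInterleaved3(b, g, r, d, rgb + 3 * i);  // re-interleave, swapped
  }
}
```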
@@ -4229,15 +4899,90 @@ HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, #endif // HWY_HAVE_TUPLE -// ------------------------------ ResizeBitCast +// Rely on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_RVV_IF_EMULATED_D. -template -HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { - const DFromV d_from; - const Repartition du8_from; - const DFromV> d_to; - const Repartition du8_to; - return BitCast(d_to, detail::ChangeLMUL(du8_to, BitCast(du8_from, v))); +// ------------------------------ Dup128VecFromValues (ResizeBitCast) + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const auto even_lanes = Set(d, t0); +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(BitCastScalar(t0) == + BitCastScalar(t1)) && + (BitCastScalar(t0) == BitCastScalar(t1))) { + return even_lanes; + } +#endif + + const auto odd_lanes = Set(d, t1); + return OddEven(odd_lanes, even_lanes); +} + +namespace detail { + +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar(detail::Vec64ValsWrapper>{ + {t0, t1, t2, t3, t4, t5, t6, t7}}), + BitCastScalar(detail::Vec64ValsWrapper>{ + {t8, t9, t10, t11, t12, t13, t14, t15}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t4, t5, t6, t7}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, + Dup128VecFromValues(du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t2, t3}}))); } // ------------------------------ PopulationCount (ShiftRight) @@ -4366,34 +5111,287 @@ HWY_API MFromD FirstN(const D d, const size_t n) { return Eq(detail::SlideUp(one, zero, n), one); } -// ------------------------------ Neg (Sub) +// ------------------------------ LowerHalfOfMask/UpperHalfOfMask -template -HWY_API V Neg(const V v) { - return detail::ReverseSubS(v, 0); +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// Target-specific implementations of LowerHalfOfMask, UpperHalfOfMask, +// CombineMasks, OrderedDemote2MasksTo, and Dup128MaskFromMaskBits are possible +// on RVV if the __riscv_vreinterpret_v_b*_u8m1 and +// __riscv_vreinterpret_v_u8m1_b* intrinsics are available. + +// The __riscv_vreinterpret_v_b*_u8m1 and __riscv_vreinterpret_v_u8m1_b* +// intrinsics available with Clang 17 and later and GCC 14 and later. 
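These `vbool` <-> `vuint8m1` reinterprets expose a mask as its packed bit pattern, which the half-mask, mask-combining and `Dup128MaskFromMaskBits` operations below build on. A small sketch of the user-facing side; the function name is assumed, and bit i of `mask_bits` controls lane i within each 128-bit block:

```cpp
#include <stddef.h>
#include <stdint.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Zeroes all but lanes 0 and 2 of every 128-bit block; n is assumed to be a
// multiple of the vector length.
void KeepLanes0And2(uint32_t* HWY_RESTRICT v, size_t n) {
  const hn::ScalableTag<uint32_t> d;
  const auto m = hn::Dup128MaskFromMaskBits(d, 0b0101u);  // lanes 0 and 2
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    hn::StoreU(hn::IfThenElseZero(m, hn::LoadU(d, v + i)), d, v + i);
  }
}
```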
+ +namespace detail { + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool1_t m) { + return __riscv_vreinterpret_v_b1_u8m1(m); } -// vector = f(vector), but argument is repeated -#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ - SHIFT, MLEN, NAME, OP) \ - HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ - return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, v, \ - HWY_RVV_AVL(SEW, SHIFT)); \ +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool2_t m) { + return __riscv_vreinterpret_v_b2_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool4_t m) { + return __riscv_vreinterpret_v_b4_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool8_t m) { + return __riscv_vreinterpret_v_b8_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool16_t m) { + return __riscv_vreinterpret_v_b16_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool32_t m) { + return __riscv_vreinterpret_v_b32_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool64_t m) { + return __riscv_vreinterpret_v_b64_u8m1(m); +} + +template , vbool1_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b1(v); +} + +template , vbool2_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b2(v); +} + +template , vbool4_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b4(v); +} + +template , vbool8_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b8(v); +} + +template , vbool16_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b16(v); +} + +template , vbool32_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b32(v); +} + +template , vbool64_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b64(v); +} + +} // namespace detail + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + return detail::U8MaskBitsVecToMask(d, detail::MaskToU8MaskBitsVec(m)); +} + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +template +HWY_API MFromD UpperHalfOfMask(D d, MFromD> m) { + const size_t N = Lanes(d); + + vuint8m1_t mask_bits = detail::MaskToU8MaskBitsVec(m); + mask_bits = ShiftRightSame(mask_bits, static_cast(N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + mask_bits = SlideDownLanes(ScalableTag(), mask_bits, N / 8); } -HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + return detail::U8MaskBitsVecToMask(d, mask_bits); +} -#if !HWY_HAVE_FLOAT16 +// ------------------------------ CombineMasks -template )> // float16_t -HWY_API V Neg(V v) { - const DFromV d; - const RebindToUnsigned du; - using TU = TFromD; - return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +template +HWY_API MFromD CombineMasks(D d, MFromD> hi, MFromD> lo) { + const Half dh; + const size_t half_N = Lanes(dh); + + const auto ext_lo_mask = + And(detail::U8MaskBitsVecToMask(d, 
detail::MaskToU8MaskBitsVec(lo)), + FirstN(d, half_N)); + vuint8m1_t hi_mask_bits = detail::MaskToU8MaskBitsVec(hi); + hi_mask_bits = ShiftLeftSame(hi_mask_bits, static_cast(half_N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + hi_mask_bits = + SlideUpLanes(ScalableTag(), hi_mask_bits, half_N / 8); + } + + return Or(ext_lo_mask, detail::U8MaskBitsVecToMask(d, hi_mask_bits)); } -#endif // !HWY_HAVE_FLOAT16 +// ------------------------------ OrderedDemote2MasksTo + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + return CombineMasks(d_to, b, a); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Dup128MaskFromMaskBits + +namespace detail { +// Even though this is only used after checking if (kN < X), this helper +// function prevents "shift count exceeded" errors. +template +constexpr unsigned MaxMaskBits() { + return (1u << kN) - 1; +} +template +constexpr unsigned MaxMaskBits() { + return ~0u; +} + +template +constexpr int SufficientPow2ForMask() { + return HWY_MAX( + D().Pow2() - 3 - static_cast(FloorLog2(sizeof(TFromD))), -3); +} +} // namespace detail + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + return detail::U8MaskBitsVecToMask( + d, Set(ScalableTag(), static_cast(mask_bits))); +#else + const RebindToUnsigned du8; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + const auto bytes = ResizeBitCast( + du8, detail::AndS( + ResizeBitCast(du64, Set(du8, static_cast(mask_bits))), + uint64_t{0x8040201008040201u})); + return detail::NeS(bytes, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + const ScalableTag()> du16; + // There are exactly 16 mask bits for 128 vector bits of 8-bit lanes. + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL( + ScalableTag(), + BitCast(du8, Set(du16, static_cast(mask_bits))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du8; + const Repartition du16; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const auto bytes = BitCast(du8, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const auto rep8 = TableLookupLanes(bytes, ShiftRight<3>(detail::Iota0(du8))); + + const auto masked_out_rep8 = ResizeBitCast( + du8, + detail::AndS(ResizeBitCast(du64, rep8), uint64_t{0x8040201008040201u})); + return detail::NeS(masked_out_rep8, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + // There are exactly 8 mask bits for 128 vector bits of 16-bit lanes. 
+ return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = + Shl(Set(du, uint16_t{1}), detail::AndS(detail::Iota0(du), 7)); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x11)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2, 4, 8); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 2) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x55)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} // ------------------------------ Abs (Max, Neg) @@ -4452,23 +5450,99 @@ HWY_API V Trunc(const V v) { } // ------------------------------ Ceil +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_CEIL_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) CeilInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RUP, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_CEIL_INT, _, _, _ALL) +#undef HWY_RVV_CEIL_INT + +} // namespace detail + template HWY_API V Ceil(const V v) { - asm volatile("fsrm %0" ::"r"(detail::kUp)); - const auto ret = Round(v); - asm volatile("fsrm %0" ::"r"(detail::kNear)); - return ret; + const DFromV df; + + const auto integer = detail::CeilInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); } +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Ceil(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. 
+ const auto pos1 = + IfThenElseZero(Lt(int_f, v), Set(df, ConvertScalarTo(1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, pos1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + // ------------------------------ Floor +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_FLOOR_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) FloorInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RDN, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_FLOOR_INT, _, _, _ALL) +#undef HWY_RVV_FLOOR_INT + +} // namespace detail + template HWY_API V Floor(const V v) { - asm volatile("fsrm %0" ::"r"(detail::kDown)); - const auto ret = Round(v); - asm volatile("fsrm %0" ::"r"(detail::kNear)); - return ret; + const DFromV df; + + const auto integer = detail::FloorInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); } +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Floor(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = + IfThenElseZero(Gt(int_f, v), Set(df, ConvertScalarTo(-1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, neg1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + // ------------------------------ Floating-point classification (Ne) // vfclass does not help because it would require 3 instructions (to AND and @@ -4479,6 +5553,14 @@ HWY_API MFromD> IsNaN(const V v) { return Ne(v, v); } +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + template > HWY_API MFromD IsInf(const V v) { const D d; @@ -4507,22 +5589,76 @@ HWY_API MFromD IsFinite(const V v) { // ------------------------------ Iota (ConvertTo) -template -HWY_API VFromD Iota(const D d, TFromD first) { - return detail::AddS(detail::Iota0(d), first); +template +HWY_API VFromD Iota(const D d, T2 first) { + return detail::AddS(detail::Iota0(d), static_cast>(first)); } -template -HWY_API VFromD Iota(const D d, TFromD first) { +template +HWY_API VFromD Iota(const D d, T2 first) { const RebindToUnsigned du; - return detail::AddS(BitCast(d, detail::Iota0(du)), first); + return detail::AddS(BitCast(d, detail::Iota0(du)), + static_cast>(first)); } -template -HWY_API VFromD Iota(const D d, TFromD first) { +template +HWY_API VFromD Iota(const D d, T2 first) { const RebindToUnsigned du; const RebindToSigned di; - return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), first); + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), + ConvertScalarTo>(first)); +} + +// ------------------------------ BitShuffle (PromoteTo, Rol, SumsOf8) + +// Native implementation required to avoid 8-bit wraparound on long vectors. 
+#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +// Cannot handle LMUL=8 because we promote indices. +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_LE_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const RebindToUnsigned du64; + const Repartition du8; + const Rebind du16; + using VU8 = VFromD; + using VU16 = VFromD; + // For each 16-bit (to avoid wraparound for long vectors) index of an output + // byte: offset of the u64 lane to which it belongs. + const VU16 byte_offsets = + detail::AndS(detail::Iota0(du16), static_cast(~7u)); + // idx is for a bit; shifting makes that bytes. Promote so we can add + // byte_offsets, then we have the u8 lane index within the whole vector. + const VU16 idx16 = + Add(byte_offsets, PromoteTo(du16, ShiftRight<3>(BitCast(du8, idx)))); + const VU8 bytes = detail::TableLookupLanes16(BitCast(du8, values), idx16); + + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. To correctly + // handle shift counts from -7 to 7, we rotate (unfortunately not natively + // supported on RVV). + const VU8 rotate_left_bits = Sub(detail::Iota0(du8), BitCast(du8, idx)); + const VU8 extracted_bits_mask = + BitCast(du8, Set(du64, static_cast(0x8040201008040201u))); + const VU8 extracted_bits = + And(Rol(bytes, rotate_left_bits), extracted_bits_mask); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. + return BitCast(D64(), SumsOf8(extracted_bits)); +} + +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_GT_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const Half dh; + const Half> dih; + using V64H = VFromD; + const V64H r0 = BitShuffle(LowerHalf(dh, values), LowerHalf(dih, idx)); + const V64H r1 = BitShuffle(UpperHalf(dh, values), UpperHalf(dih, idx)); + return Combine(D64(), r1, r0); } // ------------------------------ MulEven/Odd (Mul, OddEven) @@ -4531,7 +5667,7 @@ template , class DW = RepartitionToWide> HWY_API VFromD MulEven(const V a, const V b) { const auto lo = Mul(a, b); - const auto hi = detail::MulHigh(a, b); + const auto hi = MulHigh(a, b); return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); } @@ -4539,7 +5675,7 @@ template , class DW = RepartitionToWide> HWY_API VFromD MulOdd(const V a, const V b) { const auto lo = Mul(a, b); - const auto hi = detail::MulHigh(a, b); + const auto hi = MulHigh(a, b); return BitCast(DW(), OddEven(hi, detail::Slide1Down(lo))); } @@ -4547,28 +5683,34 @@ HWY_API VFromD MulOdd(const V a, const V b) { template HWY_INLINE V MulEven(const V a, const V b) { const auto lo = Mul(a, b); - const auto hi = detail::MulHigh(a, b); + const auto hi = MulHigh(a, b); return OddEven(detail::Slide1Up(hi), lo); } template HWY_INLINE V MulOdd(const V a, const V b) { const auto lo = Mul(a, b); - const auto hi = detail::MulHigh(a, b); + const auto hi = MulHigh(a, b); return OddEven(hi, detail::Slide1Down(lo)); } // ------------------------------ ReorderDemote2To (OddEven, Combine) -template -HWY_API VFromD> ReorderDemote2To( - Simd dbf16, - VFromD> a, - VFromD> b) { +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { const RebindToUnsigned du16; + const Half du16_half; const RebindToUnsigned> du32; - const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); + const VFromD a_in_even = PromoteTo( + du32, 
detail::DemoteTo16NearestEven(du16_half, BitCast(du32, a))); + const VFromD b_in_even = PromoteTo( + du32, detail::DemoteTo16NearestEven(du16_half, BitCast(du32, b))); + // Equivalent to InterleaveEven, but because the upper 16 bits are zero, we + // can OR instead of OddEven. + const VFromD a_in_odd = + detail::Slide1Up(BitCast(du16, a_in_even)); + return BitCast(dbf16, Or(a_in_odd, BitCast(du16, b_in_even))); } // If LMUL is not the max, Combine first to avoid another DemoteTo. @@ -4618,8 +5760,8 @@ HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { } // If LMUL is not the max, Combine first to avoid another DemoteTo. -template ), +template ), class V2 = VFromD, DN>>, hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { @@ -4629,8 +5771,8 @@ HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { } // Max LMUL: must DemoteTo first, then Combine. -template ), +template ), class V2 = VFromD, DN>>, hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { @@ -4653,68 +5795,26 @@ HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { // ------------------------------ WidenMulPairwiseAdd -template >> -HWY_API VFromD WidenMulPairwiseAdd(D32 df32, V16 a, V16 b) { - const RebindToUnsigned du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), - Mul(BitCast(df32, ao), BitCast(df32, bo))); -} - -template -HWY_API VFromD WidenMulPairwiseAdd(D d32, VI16 a, VI16 b) { - using VI32 = VFromD; - // Manual sign extension requires two shifts for even lanes. - const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); - const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); - const VI32 ao = ShiftRight<16>(BitCast(d32, a)); - const VI32 bo = ShiftRight<16>(BitCast(d32, b)); - return Add(Mul(ae, be), Mul(ao, bo)); -} - -template -HWY_API VFromD WidenMulPairwiseAdd(D du32, VI16 a, VI16 b) { - using VU32 = VFromD; - // Manual sign extension requires two shifts for even lanes. - const VU32 ae = detail::AndS(BitCast(du32, a), uint32_t{0x0000FFFFu}); - const VU32 be = detail::AndS(BitCast(du32, b), uint32_t{0x0000FFFFu}); - const VU32 ao = ShiftRight<16>(BitCast(du32, a)); - const VU32 bo = ShiftRight<16>(BitCast(du32, b)); - return Add(Mul(ae, be), Mul(ao, bo)); +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + const VFromD ae = PromoteEvenTo(df, a); + const VFromD be = PromoteEvenTo(df, b); + const VFromD ao = PromoteOddTo(df, a); + const VFromD bo = PromoteOddTo(df, b); + return MulAdd(ae, be, Mul(ao, bo)); +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) namespace detail { -// Non-overloaded wrapper function so we can define DF32 in template args. 
-template , - class VF32 = VFromD, - class DBF16 = Repartition>> -HWY_API VF32 ReorderWidenMulAccumulateBF16(Simd df32, - VFromD a, VFromD b, - const VF32 sum0, VF32& sum1) { - const RebindToUnsigned du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); -} - #define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ SHIFT, MLEN, NAME, OP) \ template \ @@ -4790,21 +5890,15 @@ HWY_API VFromD ReorderWidenMulAccumulateU16(D32 d32, VFromD a, } // namespace detail -template -HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, - const VW sum0, VW& sum1) { - return detail::ReorderWidenMulAccumulateBF16(d32, a, b, sum0, sum1); -} - -template -HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, - const VW sum0, VW& sum1) { +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); } -template -HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, - const VW sum0, VW& sum1) { +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { return detail::ReorderWidenMulAccumulateU16(d32, a, b, sum0, sum1); } @@ -4872,6 +5966,40 @@ HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { } // ------------------------------ Lt128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // The subsequent computations are performed using e8mf8 (8-bit elements with + // a fractional LMUL of 1/8) for the following reasons: + // 1. It is correct for the possible input vector types e64m<1,2,4,8>. This is + // because the resulting mask can occupy at most 1/8 of a full vector when + // using e64m8. + // 2. It can be more efficient than using a full vector or a vector group. + // + // The algorithm computes the result as follows: + // 1. Compute cH | (=H & cL) in the high bits, where cH and cL represent the + // comparison results for the high and low 64-bit elements, respectively. + // 2. Shift the result right by 1 to duplicate the comparison results for the + // low bits. + // 3. Obtain the final result by performing a bitwise OR on the high and low + // bits. 
+ auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t eqHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t ltLx0 = Add(ltHL0, ltHL0); + const vuint8mf8_t resultHx = detail::AndS(OrAnd(ltHL0, ltLx0, eqHL0), 0xaa); + const vuint8mf8_t resultxL = ShiftRight<1>(resultHx); + const vuint8mf8_t result = Or(resultHx, resultxL); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, detail::ChangeLMUL(du8m1, result)); +} + +#else + template HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4897,7 +6025,26 @@ HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); } +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + // ------------------------------ Lt128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t ltHx = detail::AndS(ltHL, 0xaa); + const vuint8mf8_t ltxL = ShiftRight<1>(ltHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(ltHx, ltxL))); +} + +#else + template HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4909,7 +6056,27 @@ HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { return MaskFromVec(OddEven(ltHL, down)); } +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + // ------------------------------ Eq128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqxH = ShiftRight<1>(eqHL); + const vuint8mf8_t result0L = detail::AndS(And(eqHL, eqxH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + template HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4921,7 +6088,26 @@ HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { return MaskFromVec(eq); } +#endif + // ------------------------------ Eq128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqHx = detail::AndS(eqHL, 0xaa); + const vuint8mf8_t eqxL = ShiftRight<1>(eqHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(eqHx, eqxL))); +} + +#else + template HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4930,7 +6116,27 
@@ HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); } +#endif + // ------------------------------ Ne128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t nexH = ShiftRight<1>(neHL); + const vuint8mf8_t result0L = detail::AndS(Or(neHL, nexH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + template HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4941,7 +6147,26 @@ HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { return MaskFromVec(Or(neHL, neLH)); } +#endif + // ------------------------------ Ne128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t neHx = detail::AndS(neHL, 0xaa); + const vuint8mf8_t nexL = ShiftRight<1>(neHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(neHx, nexL))); +} + +#else + template HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { static_assert(IsSame, uint64_t>(), "D must be u64"); @@ -4953,6 +6178,8 @@ HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { return MaskFromVec(OddEven(neHL, down)); } +#endif + // ------------------------------ Min128, Max128 (Lt128) template @@ -4994,7 +6221,6 @@ HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { } // ================================================== END MACROS -namespace detail { // for code folding #undef HWY_RVV_AVL #undef HWY_RVV_D #undef HWY_RVV_FOREACH @@ -5055,15 +6281,19 @@ namespace detail { // for code folding #undef HWY_RVV_FOREACH_UI32 #undef HWY_RVV_FOREACH_UI3264 #undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_IF_EMULATED_D +#undef HWY_RVV_IF_CAN128_D +#undef HWY_RVV_IF_GE128_D +#undef HWY_RVV_IF_LT128_D #undef HWY_RVV_INSERT_VXRM #undef HWY_RVV_M #undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGMVV #undef HWY_RVV_RETV_ARGV #undef HWY_RVV_RETV_ARGVS #undef HWY_RVV_RETV_ARGVV #undef HWY_RVV_T #undef HWY_RVV_V -} // namespace detail // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy diff --git a/r/src/vendor/highway/hwy/ops/scalar-inl.h b/r/src/vendor/highway/hwy/ops/scalar-inl.h index bdedf843..a64faf91 100644 --- a/r/src/vendor/highway/hwy/ops/scalar-inl.h +++ b/r/src/vendor/highway/hwy/ops/scalar-inl.h @@ -16,6 +16,7 @@ // Single-element vectors and operations. // External include guard in highway.h - see comment there. 
+#include #ifndef HWY_NO_LIBCXX #include // sqrtf #endif @@ -53,6 +54,9 @@ struct Vec1 { HWY_INLINE Vec1& operator-=(const Vec1 other) { return *this = (*this - other); } + HWY_INLINE Vec1& operator%=(const Vec1 other) { + return *this = (*this % other); + } HWY_INLINE Vec1& operator&=(const Vec1 other) { return *this = (*this & other); } @@ -101,17 +105,12 @@ HWY_API Vec1 BitCast(DTo /* tag */, Vec1 v) { template > HWY_API Vec1 Zero(D /* tag */) { - Vec1 v; - ZeroBytes(&v.raw); - return v; + return Vec1(ConvertScalarTo(0)); } template using VFromD = decltype(Zero(D())); -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ Set template , typename T2> HWY_API Vec1 Set(D /* tag */, const T2 t) { @@ -137,7 +136,7 @@ HWY_API VFromD ResizeBitCast(D /* tag */, FromV v) { using TFrom = TFromV; using TTo = TFromD; constexpr size_t kCopyLen = HWY_MIN(sizeof(TFrom), sizeof(TTo)); - TTo to = TTo{0}; + TTo to{}; CopyBytes(&v.raw, &to); return VFromD(to); } @@ -156,6 +155,39 @@ HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag */, } // namespace detail +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/, + TFromD /*t8*/, TFromD /*t9*/, + TFromD /*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/) { + return VFromD(t0); +} + // ================================================== LOGICAL // ------------------------------ Not @@ -300,8 +332,7 @@ HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { // ------------------------------ BroadcastSignBit template HWY_API Vec1 BroadcastSignBit(const Vec1 v) { - // This is used inside ShiftRight, so we cannot implement in terms of it. - return v.raw < 0 ? Vec1(T(-1)) : Vec1(0); + return Vec1(ScalarShr(v.raw, sizeof(T) * 8 - 1)); } // ------------------------------ PopulationCount @@ -328,12 +359,12 @@ HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, template HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { - return mask.bits ? yes : Vec1(0); + return mask.bits ? yes : Vec1(ConvertScalarTo(0)); } template HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { - return mask.bits ? Vec1(0) : no; + return mask.bits ? Vec1(ConvertScalarTo(0)) : no; } template @@ -345,11 +376,6 @@ HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { return vi.raw < 0 ? yes : no; } -template -HWY_API Vec1 ZeroIfNegative(const Vec1 v) { - return v.raw < 0 ? 
Vec1(0) : v; -} - // ------------------------------ Mask logical template @@ -407,6 +433,19 @@ HWY_API Mask1 SetAtOrBeforeFirst(Mask1 /*mask*/) { return Mask1::FromBool(true); } +// ------------------------------ LowerHalfOfMask + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API MFromD LowerHalfOfMask(D /*d*/, MFromD m) { + return m; +} + // ================================================== SHIFTS // ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) @@ -421,35 +460,20 @@ HWY_API Vec1 ShiftLeft(const Vec1 v) { template HWY_API Vec1 ShiftRight(const Vec1 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); -#if __cplusplus >= 202002L - // Signed right shift is now guaranteed to be arithmetic (rounding toward - // negative infinity, i.e. shifting in the sign bit). - return Vec1(static_cast(v.raw >> kBits)); -#else - if (IsSigned()) { - // Emulate arithmetic shift using only logical (unsigned) shifts, because - // signed shifts are still implementation-defined. - using TU = hwy::MakeUnsigned; - const Sisd du; - const TU shifted = static_cast(BitCast(du, v).raw >> kBits); - const TU sign = BitCast(du, BroadcastSignBit(v)).raw; - const size_t sign_shift = - static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); - const TU upper = static_cast(sign << sign_shift); - return BitCast(Sisd(), Vec1(shifted | upper)); - } else { // T is unsigned - return Vec1(static_cast(v.raw >> kBits)); - } -#endif + return Vec1(ScalarShr(v.raw, kBits)); } // ------------------------------ RotateRight (ShiftRight) -template +template HWY_API Vec1 RotateRight(const Vec1 v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(T) * 8; - static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift"); + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; - return Or(ShiftRight(v), + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } @@ -463,26 +487,7 @@ HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { template HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { -#if __cplusplus >= 202002L - // Signed right shift is now guaranteed to be arithmetic (rounding toward - // negative infinity, i.e. shifting in the sign bit). - return Vec1(static_cast(v.raw >> bits)); -#else - if (IsSigned()) { - // Emulate arithmetic shift using only logical (unsigned) shifts, because - // signed shifts are still implementation-defined. 
- using TU = hwy::MakeUnsigned; - const Sisd du; - const TU shifted = static_cast(BitCast(du, v).raw >> bits); - const TU sign = BitCast(du, BroadcastSignBit(v)).raw; - const size_t sign_shift = - static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); - const TU upper = static_cast(sign << sign_shift); - return BitCast(Sisd(), Vec1(shifted | upper)); - } else { // T is unsigned - return Vec1(static_cast(v.raw >> bits)); - } -#endif + return Vec1(ScalarShr(v.raw, bits)); } // ------------------------------ Shl @@ -528,10 +533,22 @@ HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { // ------------------------------ SumsOf8 +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} HWY_API Vec1 SumsOf8(const Vec1 v) { return Vec1(v.raw); } +// ------------------------------ SumsOf2 + +template +HWY_API Vec1> SumsOf2(const Vec1 v) { + const DFromV d; + const Rebind, decltype(d)> dw; + return PromoteTo(dw, v); +} + // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -590,70 +607,36 @@ HWY_API Vec1 SaturatedSub(const Vec1 a, // Returns (a + b + 1) / 2 -HWY_API Vec1 AverageRound(const Vec1 a, - const Vec1 b) { - return Vec1(static_cast((a.raw + b.raw + 1) / 2)); -} -HWY_API Vec1 AverageRound(const Vec1 a, - const Vec1 b) { - return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template +HWY_API Vec1 AverageRound(const Vec1 a, const Vec1 b) { + const T a_val = a.raw; + const T b_val = b.raw; + return Vec1(static_cast(ScalarShr(a_val, 1) + ScalarShr(b_val, 1) + + ((a_val | b_val) & 1))); } // ------------------------------ Absolute value template HWY_API Vec1 Abs(const Vec1 a) { - const T i = a.raw; - if (i >= 0 || i == hwy::LimitsMin()) return a; - return Vec1(static_cast(-i & T{-1})); -} -HWY_API Vec1 Abs(Vec1 a) { - int32_t i; - CopyBytes(&a.raw, &i); - i &= 0x7FFFFFFF; - CopyBytes(&i, &a.raw); - return a; -} -HWY_API Vec1 Abs(Vec1 a) { - int64_t i; - CopyBytes(&a.raw, &i); - i &= 0x7FFFFFFFFFFFFFFFL; - CopyBytes(&i, &a.raw); - return a; + return Vec1(ScalarAbs(a.raw)); } // ------------------------------ Min/Max // may be unavailable, so implement our own. 
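Illustration (not part of the vendored diff): the scalar AverageRound hunk above drops the u8/u16-only overloads in favor of a single formulation built on ScalarShr that never materializes a + b. A standalone check of the identity it relies on, using plain int32_t and made-up helper names:

```cpp
// Standalone check of the identity behind the new scalar AverageRound:
//   (a >> 1) + (b >> 1) + ((a | b) & 1)  ==  floor((a + b + 1) / 2)
// The left-hand side never forms a + b, so it cannot overflow the lane type.
// Plain int32_t stands in for the lane type T; >> on negative values is
// assumed to be arithmetic, which is what ScalarShr guarantees in the patch.
#include <cassert>
#include <cstdint>
#include <limits>

int32_t AverageRoundNoOverflow(int32_t a, int32_t b) {
  return (a >> 1) + (b >> 1) + ((a | b) & 1);
}

int32_t AverageRoundReference(int32_t a, int32_t b) {
  const int64_t sum = static_cast<int64_t>(a) + b + 1;
  return static_cast<int32_t>(sum >> 1);  // floor division by 2 in 64 bits
}

int main() {
  const int32_t kMax = std::numeric_limits<int32_t>::max();
  const int32_t kMin = std::numeric_limits<int32_t>::min();
  const int32_t cases[][2] = {{1, 2},       {-3, 4},      {-1, -2},
                              {kMax, kMax}, {kMin, kMin}, {kMax, kMin}};
  for (const auto& c : cases) {
    assert(AverageRoundNoOverflow(c[0], c[1]) ==
           AverageRoundReference(c[0], c[1]));
  }
  return 0;
}
```

The same identity holds for unsigned lanes, where both shifts are plain logical shifts.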
-namespace detail { - -static inline float Abs(float f) { - uint32_t i; - CopyBytes<4>(&f, &i); - i &= 0x7FFFFFFFu; - CopyBytes<4>(&i, &f); - return f; -} -static inline double Abs(double f) { - uint64_t i; - CopyBytes<8>(&f, &i); - i &= 0x7FFFFFFFFFFFFFFFull; - CopyBytes<8>(&i, &f); - return f; -} - -static inline bool SignBit(float f) { - uint32_t i; - CopyBytes<4>(&f, &i); - return (i >> 31) != 0; -} -static inline bool SignBit(double f) { - uint64_t i; - CopyBytes<8>(&f, &i); - return (i >> 63) != 0; -} - -} // namespace detail template HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { @@ -662,8 +645,8 @@ HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { template HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { - if (isnan(a.raw)) return b; - if (isnan(b.raw)) return a; + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; return Vec1(HWY_MIN(a.raw, b.raw)); } @@ -674,8 +657,8 @@ HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { template HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { - if (isnan(a.raw)) return b; - if (isnan(b.raw)) return a; + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; return Vec1(HWY_MAX(a.raw, b.raw)); } @@ -716,21 +699,24 @@ HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { static_cast(b.raw))); } -template +template HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { return Vec1(a.raw / b.raw); } -// Returns the upper 16 bits of a * b in each lane. -HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { - return Vec1(static_cast((a.raw * b.raw) >> 16)); +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + using TW = MakeWide; + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> (sizeof(T) * 8))); } -HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { - // Cast to uint32_t first to prevent overflow. Otherwise the result of - // uint16_t * uint16_t is in "int" which may overflow. In practice the result - // is the same but this way it is also defined. - return Vec1(static_cast( - (static_cast(a.raw) * static_cast(b.raw)) >> 16)); +template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + T hi; + Mul128(a.raw, b.raw, &hi); + return Vec1(hi); } HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { @@ -763,23 +749,23 @@ HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { // ------------------------------ Floating-point multiply-add variants -template +template HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { return mul * x + add; } -template +template HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { return add - mul * x; } -template +template HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { return mul * x - sub; } -template +template HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { return Neg(mul) * x - sub; @@ -842,39 +828,72 @@ HWY_API Vec1 Round(const Vec1 v) { if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN return v; } - const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); - const TI rounded = static_cast(v.raw + bias); - if (rounded == 0) return CopySignToAbs(Vec1(0), v); + const T k0 = ConvertScalarTo(0); + const T bias = ConvertScalarTo(v.raw < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(k0), v); + TI offset = 0; // Round to even - if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { - return Vec1(static_cast(rounded - (v.raw < T(0) ? 
-1 : 1))); + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = v.raw < k0 ? -1 : 1; } - return Vec1(static_cast(rounded)); + return Vec1(ConvertScalarTo(rounded - offset)); } // Round-to-nearest even. -HWY_API Vec1 NearestInt(const Vec1 v) { - using T = float; - using TI = int32_t; +template +HWY_API Vec1> NearestInt(const Vec1 v) { + using TI = MakeSigned; const T abs = Abs(v).raw; - const bool is_sign = detail::SignBit(v.raw); + const bool is_sign = ScalarSignBit(v.raw); if (!(abs < MantissaEnd())) { // Huge or NaN // Check if too large to cast or NaN - if (!(abs <= static_cast(LimitsMax()))) { + if (!(abs <= ConvertScalarTo(LimitsMax()))) { return Vec1(is_sign ? LimitsMin() : LimitsMax()); } - return Vec1(static_cast(v.raw)); + return Vec1(ConvertScalarTo(v.raw)); } - const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); - const TI rounded = static_cast(v.raw + bias); - if (rounded == 0) return Vec1(0); + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; // Round to even - if ((rounded & 1) && detail::Abs(static_cast(rounded) - v.raw) == T(0.5)) { - return Vec1(rounded - (is_sign ? -1 : 1)); + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; } - return Vec1(rounded); + return Vec1(rounded - offset); +} + +// Round-to-nearest even. +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, const Vec1 v) { + using T = double; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = ScalarSignBit(v.raw); + + // Check if too large to cast or NaN + if (!(abs <= ConvertScalarTo(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; + } + return Vec1(rounded - offset); } template @@ -883,9 +902,9 @@ HWY_API Vec1 Trunc(const Vec1 v) { if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN return v; } - const TI truncated = static_cast(v.raw); + const TI truncated = ConvertScalarTo(v.raw); if (truncated == 0) return CopySignToAbs(Vec1(0), v); - return Vec1(static_cast(truncated)); + return Vec1(ConvertScalarTo(truncated)); } template operator>=(const Vec1 a, const Vec1 b) { template HWY_API Mask1 IsNaN(const Vec1 v) { // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. - MakeUnsigned bits; - CopySameSize(&v, &bits); - bits += bits; - bits >>= 1; // clear sign bit - // NaN if all exponent bits are set and the mantissa is not zero. - return Mask1::FromBool(bits > ExponentMask()); + return Mask1::FromBool(ScalarIsNaN(v.raw)); } +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + HWY_API Mask1 IsInf(const Vec1 v) { const Sisd d; const RebindToUnsigned du; @@ -1126,6 +1147,9 @@ HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, } } +// ------------------------------ Tuples +#include "hwy/ops/inside-inl.h" + // ------------------------------ LoadInterleaved2/3/4 // Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. 
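Illustration (not part of the vendored diff): Round, NearestInt, and DemoteToNearestInt above now share the same bias-then-correct structure via ConvertScalarTo and ScalarAbs. A minimal sketch of that structure for double to int32_t, with an invented helper name and without the huge/NaN handling the patch performs first:

```cpp
// Minimal sketch of the bias-then-correct structure shared by the scalar
// Round / NearestInt / DemoteToNearestInt hunks above. It assumes |v| already
// fits in int32_t; the patch handles huge and NaN inputs separately before
// reaching this point. The helper name is invented.
#include <cassert>
#include <cmath>
#include <cstdint>

int32_t RoundHalfToEven(double v) {
  // Adding +-0.5 and truncating rounds halfway cases away from zero.
  const double bias = v < 0.0 ? -0.5 : 0.5;
  int32_t rounded = static_cast<int32_t>(v + bias);
  if (rounded == 0) return 0;
  // A tie was rounded away from zero; if that made the result odd, undo it,
  // which is exactly the round-to-even correction in the patch.
  if ((rounded & 1) && std::fabs(static_cast<double>(rounded) - v) == 0.5) {
    rounded -= (v < 0.0) ? -1 : 1;
  }
  return rounded;
}

int main() {
  assert(RoundHalfToEven(0.5) == 0);
  assert(RoundHalfToEven(1.5) == 2);
  assert(RoundHalfToEven(2.5) == 2);
  assert(RoundHalfToEven(-1.5) == -2);
  assert(RoundHalfToEven(-2.5) == -2);
  assert(RoundHalfToEven(1.25) == 1);
  return 0;
}
```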
@@ -1205,8 +1229,9 @@ HWY_API void Stream(const Vec1 v, D d, T* HWY_RESTRICT aligned) { template , typename TI> HWY_API void ScatterOffset(Vec1 v, D d, T* base, Vec1 offset) { static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); - uint8_t* const base8 = reinterpret_cast(base) + offset.raw; - Store(v, d, reinterpret_cast(base8)); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + Store(v, d, reinterpret_cast(addr)); } template , typename TI> @@ -1231,27 +1256,36 @@ HWY_API void MaskedScatterIndex(Vec1 v, Mask1 m, D d, #define HWY_NATIVE_GATHER #endif -template , typename TI> -HWY_API Vec1 GatherOffset(D d, const T* base, Vec1 offset) { - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); +template > +HWY_API Vec1 GatherOffset(D d, const T* base, Vec1> offset) { + HWY_DASSERT(offset.raw >= 0); const intptr_t addr = reinterpret_cast(base) + static_cast(offset.raw); return Load(d, reinterpret_cast(addr)); } -template , typename TI> -HWY_API Vec1 GatherIndex(D d, const T* HWY_RESTRICT base, Vec1 index) { - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); +template > +HWY_API Vec1 GatherIndex(D d, const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); return Load(d, base + index.raw); } -template , typename TI> +template > HWY_API Vec1 MaskedGatherIndex(Mask1 m, D d, const T* HWY_RESTRICT base, - Vec1 index) { - static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + Vec1> index) { + HWY_DASSERT(index.raw >= 0); return MaskedLoad(m, d, base + index.raw); } +template > +HWY_API Vec1 MaskedGatherIndexOr(Vec1 no, Mask1 m, D d, + const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return MaskedLoadOr(no, m, d, base + index.raw); +} + // ================================================== CONVERT // ConvertTo and DemoteTo with floating-point input and integer output truncate @@ -1260,73 +1294,111 @@ HWY_API Vec1 MaskedGatherIndex(Mask1 m, D d, const T* HWY_RESTRICT base, namespace detail { template -HWY_INLINE ToT CastValueForF2IConv(hwy::UnsignedTag /* to_type_tag */, - FromT val) { +HWY_INLINE ToT CastValueForF2IConv(FromT val) { // Prevent ubsan errors when converting float to narrower integer - // If LimitsMax() can be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax(). - - // Otherwise, if LimitsMax() cannot be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax() + 1, which can - // be exactly represented in FromT. - constexpr FromT kSmallestOutOfToTRangePosVal = - (sizeof(ToT) * 8 <= static_cast(MantissaBits()) + 1) - ? static_cast(LimitsMax()) - : static_cast( - static_cast(ToT{1} << (sizeof(ToT) * 8 - 1)) * FromT(2)); - - if (detail::SignBit(val)) { - return ToT{0}; - } else if (IsInf(Vec1(val)).bits || - val >= kSmallestOutOfToTRangePosVal) { - return LimitsMax(); - } else { - return static_cast(val); - } -} - -template -HWY_INLINE ToT CastValueForF2IConv(hwy::SignedTag /* to_type_tag */, - FromT val) { - // Prevent ubsan errors when converting float to narrower integer - - // If LimitsMax() can be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to LimitsMax(). - - // Otherwise, if LimitsMax() cannot be exactly represented in FromT, - // kSmallestOutOfToTRangePosVal is equal to -LimitsMin(), which can - // be exactly represented in FromT. - constexpr FromT kSmallestOutOfToTRangePosVal = - (sizeof(ToT) * 8 <= static_cast(MantissaBits()) + 2) - ? 
static_cast(LimitsMax()) - : static_cast(-static_cast(LimitsMin())); - - if (IsInf(Vec1(val)).bits || - detail::Abs(val) >= kSmallestOutOfToTRangePosVal) { - return detail::SignBit(val) ? LimitsMin() : LimitsMax(); - } else { - return static_cast(val); - } + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); } template HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { - return static_cast(val); + return ConvertScalarTo(val); } template -HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag to_type_tag, float val) { - return CastValueForF2IConv(to_type_tag, val); +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); } template -HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag to_type_tag, float val) { - return CastValueForF2IConv(to_type_tag, val); +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. +template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? 
ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(LimitsMin()); } } // namespace detail +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + template , typename TFrom> HWY_API Vec1 PromoteTo(DTo /* tag */, Vec1 from) { static_assert(sizeof(TTo) > sizeof(TFrom), "Not promoting"); @@ -1335,6 +1407,18 @@ HWY_API Vec1 PromoteTo(DTo /* tag */, Vec1 from) { detail::CastValueForPromoteTo(hwy::TypeTag(), from.raw)); } +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(DTo /* tag */, Vec1 from) { + using TTo = TFromD; + return Vec1(detail::CastValueForInRangeF2IConv(from.raw)); +} + // MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, // so we overload for TFrom=double and TTo={float,int32_t}. template @@ -1342,16 +1426,15 @@ HWY_API Vec1 DemoteTo(D /* tag */, Vec1 from) { // Prevent ubsan errors when converting float to narrower integer/float if (IsInf(from).bits || Abs(from).raw > static_cast(HighestValue())) { - return Vec1(detail::SignBit(from.raw) ? LowestValue() - : HighestValue()); + return Vec1(ScalarSignBit(from.raw) ? LowestValue() + : HighestValue()); } return Vec1(static_cast(from.raw)); } template HWY_API VFromD DemoteTo(D /* tag */, Vec1 from) { // Prevent ubsan errors when converting int32_t to narrower integer/int32_t - return Vec1>(detail::CastValueForF2IConv>( - hwy::TypeTag>(), from.raw)); + return Vec1>(detail::CastValueForF2IConv>(from.raw)); } template , typename TFrom, @@ -1365,15 +1448,30 @@ HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { return Vec1(static_cast(from.raw)); } +// Disable the default unsigned to signed DemoteTo implementation in +// generic_ops-inl.h on SCALAR as the SCALAR target has a target-specific +// implementation of the unsigned to signed DemoteTo op and as ReorderDemote2To +// is not supported on the SCALAR target + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + template , typename TFrom, - HWY_IF_UNSIGNED(TFrom), HWY_IF_UNSIGNED_D(DTo)> + HWY_IF_UNSIGNED(TFrom), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DTo)> HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { static_assert(!IsFloat(), "TFrom=double are handled above"); static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + const auto max = static_cast>(LimitsMax()); + // Int to int: choose closest value in TTo to `from` (avoids UB) - from.raw = HWY_MIN(from.raw, LimitsMax()); - return Vec1(static_cast(from.raw)); + return Vec1(static_cast(HWY_MIN(from.raw, max))); } template , typename TFrom, @@ -1383,6 +1481,19 @@ HWY_API Vec1 DemoteTo(DTo /* 
tag */, Vec1 from) { return Vec1(static_cast(from.raw)); } +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D32 /*d32*/, + VFromD> v) { + using TTo = TFromD; + return Vec1(detail::CastValueForInRangeF2IConv(v.raw)); +} + // Per-target flag to prevent generic_ops-inl.h from defining f16 conversions; // use this scalar version to verify the vector implementation. #ifdef HWY_NATIVE_F16C @@ -1401,11 +1512,22 @@ HWY_API Vec1 PromoteTo(D d, const Vec1 v) { return Set(d, F32FromBF16(v.raw)); } +template +HWY_API VFromD PromoteEvenTo(DTo d_to, Vec1 v) { + return PromoteTo(d_to, v); +} + template HWY_API Vec1 DemoteTo(D /* tag */, const Vec1 v) { return Vec1(F16FromF32(v.raw)); } +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + template HWY_API Vec1 DemoteTo(D d, const Vec1 v) { return Set(d, BF16FromF32(v.raw)); @@ -1416,8 +1538,7 @@ template , typename TFrom, HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { static_assert(sizeof(TTo) == sizeof(TFrom), "Should have same size"); // float## -> int##: return closest representable value. - return Vec1( - detail::CastValueForF2IConv(hwy::TypeTag(), from.raw)); + return Vec1(detail::CastValueForF2IConv(from.raw)); } template , typename TFrom, @@ -1428,6 +1549,19 @@ HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { return Vec1(static_cast(from.raw)); } +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { + using TTo = TFromD; + return VFromD(detail::CastValueForInRangeF2IConv(v.raw)); +} + HWY_API Vec1 U8FromU32(const Vec1 v) { return DemoteTo(Sisd(), v); } @@ -1792,6 +1926,11 @@ HWY_API Mask1 LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { return Mask1::FromBool((bits[0] & 1) != 0); } +template +HWY_API MFromD Dup128MaskFromMaskBits(D /*d*/, unsigned mask_bits) { + return MFromD::FromBool((mask_bits & 1) != 0); +} + // `p` points to at least 8 writable bytes. template > HWY_API size_t StoreMaskBits(D d, const Mask1 mask, uint8_t* bits) { @@ -1910,6 +2049,35 @@ HWY_API Vec1 WidenMulPairwiseAdd(D32 /* tag */, Vec1 a, return Vec1(a.raw * b.raw); } +// ------------------------------ SatWidenMulAccumFixedPoint +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + // Multiplying static_cast(a.raw) by static_cast(b.raw) + // followed by an addition of the product is okay as + // (a.raw * b.raw * 2) is between -2147418112 and 2147483648 and as + // a.raw * b.raw * 2 can only overflow an int32_t if both a.raw and b.raw are + // equal to -32768. 
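Illustration (not part of the vendored diff): the comment ending the hunk above bounds 2*a*b for int16_t inputs; the function body continues below. For reference, here is the obvious widened-and-saturated formulation of the same operation plus a spot check of that bound, with an invented function name:

```cpp
// Reference formulation of the scalar SatWidenMulAccumFixedPoint whose body
// continues below: widen to 64 bits, double the product, accumulate, then
// saturate to int32_t. Also spot-checks the bound quoted in the comment above:
// for int16_t inputs, 2*a*b only exceeds INT32_MAX when a == b == -32768.
#include <cassert>
#include <cstdint>
#include <limits>

int32_t SatWidenMulAccumFixedPointRef(int16_t a, int16_t b, int32_t sum) {
  const int64_t product2 =
      2 * static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int64_t total = product2 + sum;
  const int64_t lo = std::numeric_limits<int32_t>::min();
  const int64_t hi = std::numeric_limits<int32_t>::max();
  return static_cast<int32_t>(total < lo ? lo : (total > hi ? hi : total));
}

int main() {
  // Corners of 2*a*b over the int16_t range, matching the comment above.
  assert(2 * int64_t{-32768} * int64_t{-32768} == int64_t{2147483648});
  assert(2 * int64_t{-32768} * int64_t{32767} == int64_t{-2147418112});
  // The single overflowing pair saturates whenever the accumulator is >= 0.
  assert(SatWidenMulAccumFixedPointRef(-32768, -32768, 0) ==
         std::numeric_limits<int32_t>::max());
  assert(SatWidenMulAccumFixedPointRef(-32768, 32767, 0) == -2147418112);
  assert(SatWidenMulAccumFixedPointRef(100, 200, 5) == 40005);
  return 0;
}
```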
+ + const VFromD product(static_cast(a.raw) * + static_cast(b.raw)); + const VFromD product2 = Add(product, product); + + const auto mul_overflow = + VecFromMask(di32, Eq(product2, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product2, mul_overflow)); +} + // ------------------------------ SatWidenMulPairwiseAdd #ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD @@ -1937,6 +2105,12 @@ HWY_API Vec1 SatWidenMulPairwiseAdd(DI16 /* tag */, Vec1 a, // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + template HWY_API Vec1 ReorderWidenMulAccumulate(D32 /* tag */, Vec1 a, Vec1 b, @@ -1971,23 +2145,7 @@ HWY_API Vec1 RearrangeToOddPlusEven(Vec1 sum0, Vec1 /* sum1 */) { // ================================================== REDUCTIONS -// Sum of all lanes, i.e. the only one. -template > -HWY_API Vec1 SumOfLanes(D /* tag */, const Vec1 v) { - return v; -} -template > -HWY_API T ReduceSum(D /* tag */, const Vec1 v) { - return GetLane(v); -} -template > -HWY_API Vec1 MinOfLanes(D /* tag */, const Vec1 v) { - return v; -} -template > -HWY_API Vec1 MaxOfLanes(D /* tag */, const Vec1 v) { - return v; -} +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE diff --git a/r/src/vendor/highway/hwy/ops/set_macros-inl.h b/r/src/vendor/highway/hwy/ops/set_macros-inl.h index d8bed3e2..f955f936 100644 --- a/r/src/vendor/highway/hwy/ops/set_macros-inl.h +++ b/r/src/vendor/highway/hwy/ops/set_macros-inl.h @@ -1,5 +1,7 @@ // Copyright 2020 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates // SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -41,9 +43,31 @@ #undef HWY_HAVE_FLOAT64 #undef HWY_MEM_OPS_MIGHT_FAULT #undef HWY_NATIVE_FMA +#undef HWY_NATIVE_DOT_BF16 #undef HWY_CAP_GE256 #undef HWY_CAP_GE512 +#undef HWY_TARGET_IS_SVE +#if HWY_TARGET & HWY_ALL_SVE +#define HWY_TARGET_IS_SVE 1 +#else +#define HWY_TARGET_IS_SVE 0 +#endif + +#undef HWY_TARGET_IS_NEON +#if HWY_TARGET & HWY_ALL_NEON +#define HWY_TARGET_IS_NEON 1 +#else +#define HWY_TARGET_IS_NEON 0 +#endif + +#undef HWY_TARGET_IS_PPC +#if HWY_TARGET & HWY_ALL_PPC +#define HWY_TARGET_IS_PPC 1 +#else +#define HWY_TARGET_IS_PPC 0 +#endif + // Supported on all targets except RVV (requires GCC 14 or upcoming Clang) #if HWY_TARGET == HWY_RVV && \ ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \ @@ -116,7 +140,21 @@ ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avx512vnni,avx512bitalg," \ "avx512vpopcntdq,gfni" -#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_DL ",avx512fp16" +// Force-disable for compilers that do not properly support avx512bf16. 
+#if !defined(HWY_AVX3_DISABLE_AVX512BF16) && \ + (HWY_COMPILER_CLANGCL || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 900)) +#define HWY_AVX3_DISABLE_AVX512BF16 +#endif + +#if !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL ",avx512bf16" +#else +#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL +#endif + +#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_ZEN4 ",avx512fp16" #if defined(HWY_DISABLE_PPC8_CRYPTO) #define HWY_TARGET_STR_PPC8_CRYPTO "" @@ -131,9 +169,21 @@ #if HWY_COMPILER_CLANG #define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",power10-vector" #else -#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",cpu=power10" +// See #1707 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102059#c35. +// When the baseline is PPC 8 or 9, inlining functions such as PreventElision +// into PPC10 code fails because PPC10 defaults to no-htm and is thus worse than +// the baseline, which has htm. We cannot have pragma target on functions +// outside HWY_NAMESPACE such as those in base.h. It would be possible for users +// to set -mno-htm globally, but we can also work around this at the library +// level by claiming that PPC10 still has HTM, thus avoiding the mismatch. This +// seems to be safe because HTM uses builtins rather than modifying codegen, see +// https://gcc.gnu.org/legacy-ml/gcc-patches/2013-07/msg00167.html. +#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",cpu=power10,htm" #endif +#define HWY_TARGET_STR_Z14 "arch=z14" +#define HWY_TARGET_STR_Z15 "arch=z15" + // Before include guard so we redefine HWY_TARGET_STR on each include, // governed by the current HWY_TARGET. @@ -152,6 +202,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -171,6 +222,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -191,6 +243,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -216,6 +269,7 @@ #else #define HWY_NATIVE_FMA 1 #endif +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 1 #define HWY_CAP_GE512 0 @@ -233,7 +287,9 @@ #define HWY_HAVE_SCALABLE 0 #define HWY_HAVE_INTEGER64 1 -#if (HWY_TARGET == HWY_AVX3_SPR) && 0 // TODO(janwas): enable after testing +#if HWY_TARGET == HWY_AVX3_SPR && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1901) && \ + HWY_HAVE_SCALAR_F16_TYPE #define HWY_HAVE_FLOAT16 1 #else #define HWY_HAVE_FLOAT16 0 @@ -241,6 +297,11 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 0 #define HWY_NATIVE_FMA 1 +#if (HWY_TARGET <= HWY_AVX3_ZEN4) && !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif #define HWY_CAP_GE256 1 #define HWY_CAP_GE512 1 @@ -257,8 +318,7 @@ #elif HWY_TARGET == HWY_AVX3_ZEN4 #define HWY_NAMESPACE N_AVX3_ZEN4 -// Currently the same as HWY_AVX3_DL: both support Icelake. 
-#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_DL +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_ZEN4 #elif HWY_TARGET == HWY_AVX3_SPR @@ -271,8 +331,7 @@ //----------------------------------------------------------------------------- // PPC8, PPC9, PPC10 -#elif HWY_TARGET == HWY_PPC8 || HWY_TARGET == HWY_PPC9 || \ - HWY_TARGET == HWY_PPC10 +#elif HWY_TARGET_IS_PPC #define HWY_ALIGN alignas(16) #define HWY_MAX_BYTES 16 @@ -284,6 +343,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -304,11 +364,43 @@ #else #error "Logic error" -#endif // HWY_TARGET == HWY_PPC10 +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- +// Z14, Z15 +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_Z14 + +#define HWY_NAMESPACE N_Z14 +#define HWY_TARGET_STR HWY_TARGET_STR_Z14 + +#elif HWY_TARGET == HWY_Z15 + +#define HWY_NAMESPACE N_Z15 +#define HWY_TARGET_STR HWY_TARGET_STR_Z15 + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_Z15 //----------------------------------------------------------------------------- // NEON -#elif HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_NEON_WITHOUT_AES +#elif HWY_TARGET_IS_NEON #define HWY_ALIGN alignas(16) #define HWY_MAX_BYTES 16 @@ -316,7 +408,7 @@ #define HWY_HAVE_SCALABLE 0 #define HWY_HAVE_INTEGER64 1 -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || HWY_TARGET == HWY_NEON_BF16 #define HWY_HAVE_FLOAT16 1 #else #define HWY_HAVE_FLOAT16 0 @@ -330,20 +422,29 @@ #define HWY_MEM_OPS_MIGHT_FAULT 1 -#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#if defined(__ARM_FEATURE_FMA) || defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 #define HWY_NATIVE_FMA 1 #else #define HWY_NATIVE_FMA 0 #endif +#if HWY_NEON_HAVE_F32_TO_BF16C || HWY_TARGET == HWY_NEON_BF16 +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 #if HWY_TARGET == HWY_NEON_WITHOUT_AES #define HWY_NAMESPACE N_NEON_WITHOUT_AES -#else +#elif HWY_TARGET == HWY_NEON #define HWY_NAMESPACE N_NEON -#endif +#elif HWY_TARGET == HWY_NEON_BF16 +#define HWY_NAMESPACE N_NEON_BF16 +#else +#error "Logic error, missing case" +#endif // HWY_TARGET // Can use pragmas instead of -march compiler flag #if HWY_HAVE_RUNTIME_DISPATCH @@ -358,21 +459,43 @@ #else // !HWY_ARCH_ARM_V7 +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1300) +// GCC 12 or earlier and Clang 12 or earlier require +crypto be added to the +// target string to enable AArch64 AES intrinsics +#define HWY_TARGET_STR_NEON "+crypto" +#else +#define HWY_TARGET_STR_NEON "+aes" +#endif + +// Clang >= 16 requires +fullfp16 instead of fp16, but Apple Clang 15 = 1600 +// fails to parse unless the string starts with armv8, whereas 1700 refuses it. 
+#if HWY_COMPILER_CLANG >= 1700 +#define HWY_TARGET_STR_FP16 "+fullfp16" +#elif HWY_COMPILER_CLANG >= 1600 && defined(__apple_build_version__) +#define HWY_TARGET_STR_FP16 "armv8.4-a+fullfp16" +#else +#define HWY_TARGET_STR_FP16 "+fp16" +#endif + #if HWY_TARGET == HWY_NEON_WITHOUT_AES // Do not define HWY_TARGET_STR (no pragma). +#elif HWY_TARGET == HWY_NEON +#define HWY_TARGET_STR HWY_TARGET_STR_NEON +#elif HWY_TARGET == HWY_NEON_BF16 +#define HWY_TARGET_STR HWY_TARGET_STR_FP16 "+bf16+dotprod" HWY_TARGET_STR_NEON #else -#define HWY_TARGET_STR "+crypto" -#endif // HWY_TARGET == HWY_NEON_WITHOUT_AES +#error "Logic error, missing case" +#endif // HWY_TARGET -#endif // HWY_ARCH_ARM_V7 +#endif // !HWY_ARCH_ARM_V7 #else // !HWY_HAVE_RUNTIME_DISPATCH // HWY_TARGET_STR remains undefined #endif //----------------------------------------------------------------------------- // SVE[2] -#elif HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE || \ - HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#elif HWY_TARGET_IS_SVE // SVE only requires lane alignment, not natural alignment of the entire vector. #define HWY_ALIGN alignas(8) @@ -382,10 +505,15 @@ #define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) #define HWY_HAVE_INTEGER64 1 -#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT16 1 #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 0 #define HWY_NATIVE_FMA 1 +#if HWY_SVE_HAVE_BF16_FEATURE +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -410,11 +538,17 @@ // Can use pragmas instead of -march compiler flag #if HWY_HAVE_RUNTIME_DISPATCH #if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 -#define HWY_TARGET_STR "+sve2-aes" -#else +// Static dispatch with -march=armv8-a+sve2+aes, or no baseline, hence dynamic +// dispatch, which checks for AES support at runtime. +#if defined(__ARM_FEATURE_SVE2_AES) || (HWY_BASELINE_SVE2 == 0) +#define HWY_TARGET_STR "+sve2+sve2-aes,+sve" +#else // SVE2 without AES +#define HWY_TARGET_STR "+sve2,+sve" +#endif +#else // not SVE2 target #define HWY_TARGET_STR "+sve" #endif -#else +#else // !HWY_HAVE_RUNTIME_DISPATCH // HWY_TARGET_STR remains undefined #endif @@ -432,6 +566,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -450,9 +585,10 @@ #define HWY_HAVE_SCALABLE 0 #define HWY_HAVE_INTEGER64 1 #define HWY_HAVE_FLOAT16 0 -#define HWY_HAVE_FLOAT64 0 +#define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 1 #define HWY_CAP_GE512 0 @@ -480,10 +616,11 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 0 #define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 -#if defined(__riscv_zvfh) +#if HWY_RVV_HAVE_F16_VEC #define HWY_HAVE_FLOAT16 1 #else #define HWY_HAVE_FLOAT16 0 @@ -491,8 +628,12 @@ #define HWY_NAMESPACE N_RVV +#if HWY_COMPILER_CLANG >= 1900 +// https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#181-zvl-minimum-vector-length-standard-extensions +#define HWY_TARGET_STR "Zvl128b,Zve64d" +#else // HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. 
-// (rv64gcv is not a valid target) +#endif //----------------------------------------------------------------------------- // EMU128 @@ -508,6 +649,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 1 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -529,6 +671,7 @@ #define HWY_HAVE_FLOAT64 1 #define HWY_MEM_OPS_MIGHT_FAULT 0 #define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 #define HWY_CAP_GE256 0 #define HWY_CAP_GE512 0 @@ -540,6 +683,14 @@ #pragma message("HWY_TARGET does not match any known target") #endif // HWY_TARGET +//----------------------------------------------------------------------------- + +// Sanity check: if we have f16 vector support, then base.h should also be +// using a built-in type for f16 scalars. +#if HWY_HAVE_FLOAT16 && !HWY_HAVE_SCALAR_F16_TYPE +#error "Logic error: f16 vectors but no scalars" +#endif + // Override this to 1 in asan/msan builds, which will still fault. #if HWY_IS_ASAN || HWY_IS_MSAN #undef HWY_MEM_OPS_MIGHT_FAULT diff --git a/r/src/vendor/highway/hwy/ops/shared-inl.h b/r/src/vendor/highway/hwy/ops/shared-inl.h index d322281d..bb18020e 100644 --- a/r/src/vendor/highway/hwy/ops/shared-inl.h +++ b/r/src/vendor/highway/hwy/ops/shared-inl.h @@ -26,6 +26,7 @@ #endif #include "hwy/detect_compiler_arch.h" +#include "hwy/detect_targets.h" // Separate header because foreach_target.h re-enables its include guard. #include "hwy/ops/set_macros-inl.h" @@ -38,7 +39,9 @@ // We are covered by the highway.h include guard, but generic_ops-inl.h // includes this again #if HWY_IDE. -#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) +// clang-format off +#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on #ifdef HIGHWAY_HWY_OPS_SHARED_TOGGLE #undef HIGHWAY_HWY_OPS_SHARED_TOGGLE #else @@ -59,6 +62,10 @@ namespace HWY_NAMESPACE { // We therefore pass by const& only on GCC and (Windows or aarch64). This alias // must be used for all vector/mask parameters of functions marked HWY_NOINLINE, // and possibly also other functions that are not inlined. +// +// Even better is to avoid passing vector arguments to non-inlined functions, +// because the SVE and RISC-V ABIs are still works in progress and may lead to +// incorrect codegen. #if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) template using VecArg = const V&; @@ -69,27 +76,75 @@ using VecArg = V; namespace detail { -// Primary template: default is no change for all but f16. template struct NativeLaneTypeT { using type = T; }; - template <> struct NativeLaneTypeT { - using type = hwy::float16_t::Raw; +#if HWY_HAVE_SCALAR_F16_TYPE + using type = hwy::float16_t::Native; +#else + using type = uint16_t; +#endif }; - template <> struct NativeLaneTypeT { - using type = hwy::bfloat16_t::Raw; +#if HWY_HAVE_SCALAR_BF16_TYPE + using type = hwy::bfloat16_t::Native; +#else + using type = uint16_t; +#endif }; -// Evaluates to the type expected by intrinsics given the Highway lane type T. -// This is usually the same, but differs for our wrapper types [b]float16_t. +// The type expected by intrinsics for the given Highway lane type T. This +// usually matches T, but differs for our wrapper types [b]float16_t. Use this +// only when defining intrinsic wrappers, and NOT for casting, which is UB. template using NativeLaneType = typename NativeLaneTypeT::type; +// Returns the same pointer after changing type to NativeLaneType. 
Use this only +// for wrapper functions that call intrinsics (e.g. load/store) where some of +// the overloads expect _Float16* or __bf16* arguments. For non-special floats, +// this returns the same pointer and type. +// +// This makes use of the fact that a wrapper struct is pointer-interconvertible +// with its first member (a union), thus also with the union members. Do NOT +// call both this and U16LanePointer on the same object - they access different +// union members, and this is not guaranteed to be safe. +template +HWY_INLINE T* NativeLanePointer(T* p) { + return p; +} +template >, + HWY_IF_F16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_F16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} +template >, + HWY_IF_BF16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_BF16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} + +// Returns a pointer to the u16 member of our [b]float16_t wrapper structs. +// Use this in Highway targets that lack __bf16 intrinsics; for storing to +// memory, we BitCast vectors to u16 and write to the pointer returned here. +// Do NOT call both this and U16LanePointer on the same object - they access +// different union members, and this is not guaranteed to be safe. +template +HWY_INLINE If(), const uint16_t*, uint16_t*> U16LanePointer(T* p) { + return &p->bits; +} + // Returns N * 2^pow2. N is the number of lanes in a full vector and pow2 the // desired fraction or multiple of it, see Simd<>. `pow2` is most often in // [-3, 3] but can also be lower for user-specified fractions. @@ -97,6 +152,16 @@ constexpr size_t ScaleByPower(size_t N, int pow2) { return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); } +template +HWY_INLINE void MaybePoison(T* HWY_RESTRICT unaligned, size_t count) { +#if HWY_IS_MSAN + __msan_poison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + template HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { // Workaround for MSAN not marking compressstore as initialized (b/233326619) @@ -151,6 +216,13 @@ struct Simd { private: static_assert(sizeof(Lane) <= 8, "Lanes are up to 64-bit"); + static_assert(IsSame>(), + "Lane must not be a reference type, const-qualified type, or " + "volatile-qualified type"); + static_assert(IsIntegerLaneType() || IsFloat() || + IsSpecialFloat(), + "IsIntegerLaneType(), IsFloat(), or IsSpecialFloat() " + "must be true"); // 20 bits are sufficient for any HWY_MAX_BYTES. This is the 'normal' value of // N when kFrac == 0, otherwise it is one (see FracN). static constexpr size_t kWhole = N & 0xFFFFF; @@ -185,11 +257,14 @@ struct Simd { // macro required by MSVC. static constexpr size_t kPrivateLanes = HWY_MAX(size_t{1}, detail::ScaleByPower(kWhole, kPow2 - kFrac)); + // Do not use this directly - only 'public' so it is visible from the accessor + // macro required by MSVC. + static constexpr int kPrivatePow2 = kPow2; constexpr size_t MaxLanes() const { return kPrivateLanes; } constexpr size_t MaxBytes() const { return kPrivateLanes * sizeof(Lane); } constexpr size_t MaxBlocks() const { return (MaxBytes() + 15) / 16; } - // For SFINAE on RVV. + // For SFINAE (HWY_IF_POW2_GT_D). constexpr int Pow2() const { return kPow2; } // ------------------------------ Changing lane type or count @@ -371,6 +446,10 @@ using TFromD = typename D::T; // MSVC workaround: use static constant directly instead of a function. 
#define HWY_MAX_LANES_D(D) D::kPrivateLanes +// Same as D().Pow2(), but this is too complex for SFINAE with MSVC, so we use a +// static constant directly. +#define HWY_POW2_D(D) D::kPrivatePow2 + // Non-macro form of HWY_MAX_LANES_D in case that is preferable. WARNING: the // macro form may be required for MSVC, which has limitations on deducing // arguments. @@ -411,6 +490,13 @@ using RepartitionToWide = Repartition>, D>; template using RepartitionToNarrow = Repartition>, D>; +// Shorthand for applying RepartitionToWide twice (for 8/16-bit types). +template +using RepartitionToWideX2 = RepartitionToWide>; +// Shorthand for applying RepartitionToWide three times (for 8-bit types). +template +using RepartitionToWideX3 = RepartitionToWide>; + // Tag for the same lane type as D, but half the lanes. template using Half = typename D::Half; @@ -447,85 +533,152 @@ using BlockDFromD = Simd, HWY_MIN(16 / sizeof(TFromD), HWY_MAX_LANES_D(D)), 0>; #endif +// Returns whether `ptr` is a multiple of `Lanes(d)` elements. +template +HWY_API bool IsAligned(D d, T* ptr) { + const size_t N = Lanes(d); + return reinterpret_cast(ptr) % (N * sizeof(T)) == 0; +} + // ------------------------------ Choosing overloads (SFINAE) // Same as base.h macros but with a Simd argument instead of T. -#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD) -#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD) -#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD) -#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(TFromD) -#define HWY_IF_FLOAT3264_D(D) HWY_IF_FLOAT3264(TFromD) -#define HWY_IF_NOT_FLOAT3264_D(D) HWY_IF_NOT_FLOAT3264(TFromD) -#define HWY_IF_SPECIAL_FLOAT_D(D) HWY_IF_SPECIAL_FLOAT(TFromD) -#define HWY_IF_NOT_SPECIAL_FLOAT_D(D) HWY_IF_NOT_SPECIAL_FLOAT(TFromD) -#define HWY_IF_FLOAT_OR_SPECIAL_D(D) HWY_IF_FLOAT_OR_SPECIAL(TFromD) +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_UNSIGNED_D(D) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT3264_D(D) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT3264_D(D) \ + HWY_IF_NOT_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SPECIAL_FLOAT_D(D) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_SPECIAL_FLOAT_D(D) \ + HWY_IF_NOT_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_OR_SPECIAL_D(D) \ + HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) #define HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D) \ - HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD) + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) -#define HWY_IF_T_SIZE_D(D, bytes) HWY_IF_T_SIZE(TFromD, bytes) -#define HWY_IF_NOT_T_SIZE_D(D, bytes) HWY_IF_NOT_T_SIZE(TFromD, bytes) +#define HWY_IF_T_SIZE_D(D, bytes) \ + HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_NOT_T_SIZE_D(D, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) #define HWY_IF_T_SIZE_ONE_OF_D(D, bit_array) \ - HWY_IF_T_SIZE_ONE_OF(TFromD, bit_array) + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromD, bit_array) +#define HWY_IF_T_SIZE_LE_D(D, bytes) \ + HWY_IF_T_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_T_SIZE_GT_D(D, bytes) \ + HWY_IF_T_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, bytes) #define HWY_IF_LANES_D(D, lanes) HWY_IF_LANES(HWY_MAX_LANES_D(D), lanes) #define HWY_IF_LANES_LE_D(D, lanes) 
HWY_IF_LANES_LE(HWY_MAX_LANES_D(D), lanes) #define HWY_IF_LANES_GT_D(D, lanes) HWY_IF_LANES_GT(HWY_MAX_LANES_D(D), lanes) -#define HWY_IF_LANES_PER_BLOCK_D(D, lanes) \ - HWY_IF_LANES_PER_BLOCK( \ - TFromD, HWY_MIN(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD)), lanes) - +#define HWY_IF_LANES_PER_BLOCK_D(D, lanes) \ + HWY_IF_LANES_PER_BLOCK(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), \ + lanes) + +#if HWY_COMPILER_MSVC +#define HWY_IF_POW2_LE_D(D, pow2) \ + hwy::EnableIf* = nullptr +#define HWY_IF_POW2_GT_D(D, pow2) \ + hwy::EnableIf<(HWY_POW2_D(D) > pow2)>* = nullptr +#else #define HWY_IF_POW2_LE_D(D, pow2) hwy::EnableIf* = nullptr #define HWY_IF_POW2_GT_D(D, pow2) hwy::EnableIf<(D().Pow2() > pow2)>* = nullptr +#endif // HWY_COMPILER_MSVC -#define HWY_IF_U8_D(D) hwy::EnableIf, uint8_t>()>* = nullptr -#define HWY_IF_U16_D(D) hwy::EnableIf, uint16_t>()>* = nullptr -#define HWY_IF_U32_D(D) hwy::EnableIf, uint32_t>()>* = nullptr -#define HWY_IF_U64_D(D) hwy::EnableIf, uint64_t>()>* = nullptr +#define HWY_IF_U8_D(D) HWY_IF_U8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U16_D(D) HWY_IF_U16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U32_D(D) HWY_IF_U32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U64_D(D) HWY_IF_U64(hwy::HWY_NAMESPACE::TFromD) -#define HWY_IF_I8_D(D) hwy::EnableIf, int8_t>()>* = nullptr -#define HWY_IF_I16_D(D) hwy::EnableIf, int16_t>()>* = nullptr -#define HWY_IF_I32_D(D) hwy::EnableIf, int32_t>()>* = nullptr -#define HWY_IF_I64_D(D) hwy::EnableIf, int64_t>()>* = nullptr +#define HWY_IF_I8_D(D) HWY_IF_I8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I16_D(D) HWY_IF_I16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I32_D(D) HWY_IF_I32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I64_D(D) HWY_IF_I64(hwy::HWY_NAMESPACE::TFromD) // Use instead of HWY_IF_T_SIZE_D to avoid ambiguity with float16_t/float/double // overloads. -#define HWY_IF_UI16_D(D) HWY_IF_UI16(TFromD) -#define HWY_IF_UI32_D(D) HWY_IF_UI32(TFromD) -#define HWY_IF_UI64_D(D) HWY_IF_UI64(TFromD) +#define HWY_IF_UI8_D(D) HWY_IF_UI8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI16_D(D) HWY_IF_UI16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI32_D(D) HWY_IF_UI32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI64_D(D) HWY_IF_UI64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_BF16_D(D) HWY_IF_BF16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_BF16_D(D) HWY_IF_NOT_BF16(hwy::HWY_NAMESPACE::TFromD) -#define HWY_IF_BF16_D(D) HWY_IF_BF16(TFromD) -#define HWY_IF_F16_D(D) HWY_IF_F16(TFromD) -#define HWY_IF_F32_D(D) hwy::EnableIf, float>()>* = nullptr -#define HWY_IF_F64_D(D) hwy::EnableIf, double>()>* = nullptr +#define HWY_IF_F16_D(D) HWY_IF_F16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_F16_D(D) HWY_IF_NOT_F16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_F32_D(D) HWY_IF_F32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_F64_D(D) HWY_IF_F64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_V_SIZE_D(D) \ + (HWY_MAX_LANES_D(D) * sizeof(hwy::HWY_NAMESPACE::TFromD)) #define HWY_IF_V_SIZE_D(D, bytes) \ - HWY_IF_V_SIZE(TFromD, HWY_MAX_LANES_D(D), bytes) + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) #define HWY_IF_V_SIZE_LE_D(D, bytes) \ - HWY_IF_V_SIZE_LE(TFromD, HWY_MAX_LANES_D(D), bytes) + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) #define HWY_IF_V_SIZE_GT_D(D, bytes) \ - HWY_IF_V_SIZE_GT(TFromD, HWY_MAX_LANES_D(D), bytes) + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) // Same, but with a vector argument. ops/*-inl.h define their own TFromV. 
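Illustration (not part of the vendored diff): the shared-inl.h hunks here rewrite the HWY_IF_*_D / HWY_IF_*_V helpers to spell out hwy::HWY_NAMESPACE::TFromD / TFromV, which appears intended to keep the macros resolving when expanded outside the library's own namespace (an inference from the change, not stated in the patch). They all boil down to the enable_if-in-template-parameter idiom; a minimal standalone sketch with invented names:

```cpp
// Minimal sketch of the enable_if-in-template-parameter idiom that the
// HWY_IF_*_D / HWY_IF_*_V macros expand to. The extra defaulted template
// parameter removes an overload from consideration unless the lane-type
// predicate holds, so an op can provide one body per category of lane type.
// All names here are invented stand-ins.
#include <cstdio>
#include <type_traits>

template <bool B>
using EnableIf = typename std::enable_if<B>::type;

// Stand-ins for HWY_IF_FLOAT(T) / HWY_IF_NOT_FLOAT(T).
#define IF_FLOAT(T) EnableIf<std::is_floating_point<T>::value>* = nullptr
#define IF_NOT_FLOAT(T) EnableIf<!std::is_floating_point<T>::value>* = nullptr

template <typename T, IF_FLOAT(T)>
const char* Describe(T) {
  return "float path";
}

template <typename T, IF_NOT_FLOAT(T)>
const char* Describe(T) {
  return "integer path";
}

int main() {
  printf("%s\n", Describe(1.0f));  // float path
  printf("%s\n", Describe(42));    // integer path
  return 0;
}
```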
-#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV) -#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV) -#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV) -#define HWY_IF_NOT_FLOAT_V(V) HWY_IF_NOT_FLOAT(TFromV) -#define HWY_IF_SPECIAL_FLOAT_V(V) HWY_IF_SPECIAL_FLOAT(TFromV) +#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_UNSIGNED_V(V) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_FLOAT_V(V) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SPECIAL_FLOAT_V(V) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromV) #define HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) \ - HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromV) + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromV) -#define HWY_IF_T_SIZE_V(V, bytes) HWY_IF_T_SIZE(TFromV, bytes) -#define HWY_IF_NOT_T_SIZE_V(V, bytes) HWY_IF_NOT_T_SIZE(TFromV, bytes) +#define HWY_IF_T_SIZE_V(V, bytes) \ + HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) +#define HWY_IF_NOT_T_SIZE_V(V, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) #define HWY_IF_T_SIZE_ONE_OF_V(V, bit_array) \ - HWY_IF_T_SIZE_ONE_OF(TFromV, bit_array) + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromV, bit_array) -#define HWY_MAX_LANES_V(V) HWY_MAX_LANES_D(DFromV) +#define HWY_MAX_LANES_V(V) HWY_MAX_LANES_D(hwy::HWY_NAMESPACE::DFromV) #define HWY_IF_V_SIZE_V(V, bytes) \ - HWY_IF_V_SIZE(TFromV, HWY_MAX_LANES_V(V), bytes) + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) #define HWY_IF_V_SIZE_LE_V(V, bytes) \ - HWY_IF_V_SIZE_LE(TFromV, HWY_MAX_LANES_V(V), bytes) + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) #define HWY_IF_V_SIZE_GT_V(V, bytes) \ - HWY_IF_V_SIZE_GT(TFromV, HWY_MAX_LANES_V(V), bytes) + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) + +// Use in implementations of ReduceSum etc. to avoid conflicts with the N=1 and +// N=4 8-bit specializations in generic_ops-inl. +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) \ + hwy::EnableIf) != 1)>* = nullptr + +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +// HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V is used to disable the default +// implementation of unsigned to signed DemoteTo/ReorderDemote2To in +// generic_ops-inl.h for at least some of the unsigned to signed demotions on +// SCALAR/EMU128/SSE2/SSSE3/SSE4/AVX2/SVE/SVE2 + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) void* = nullptr // Old names (deprecated) #define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_T_SIZE_D(D, bytes) diff --git a/r/src/vendor/highway/hwy/ops/tuple-inl.h b/r/src/vendor/highway/hwy/ops/tuple-inl.h deleted file mode 100644 index 9def0610..00000000 --- a/r/src/vendor/highway/hwy/ops/tuple-inl.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2023 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tuple support. Included by those ops/* that lack native tuple types, after -// they define VFromD and before they use the tuples e.g. for LoadInterleaved2. -// Assumes we are already in the HWY_NAMESPACE and under an include guard. - -// If viewing this header standalone, define VFromD to avoid IDE warnings. -// This is normally set by set_macros-inl.h before this header is included. -#if !defined(HWY_NAMESPACE) -#include "hwy/base.h" -template -using VFromD = int; -#endif - -// On SVE, Vec2..4 are aliases to built-in types. -template -struct Vec2 { - VFromD v0; - VFromD v1; -}; - -template -struct Vec3 { - VFromD v0; - VFromD v1; - VFromD v2; -}; - -template -struct Vec4 { - VFromD v0; - VFromD v1; - VFromD v2; - VFromD v3; -}; - -// D arg is unused but allows deducing D. -template -HWY_API Vec2 Create2(D /* tag */, VFromD v0, VFromD v1) { - return Vec2{v0, v1}; -} - -template -HWY_API Vec3 Create3(D /* tag */, VFromD v0, VFromD v1, VFromD v2) { - return Vec3{v0, v1, v2}; -} - -template -HWY_API Vec4 Create4(D /* tag */, VFromD v0, VFromD v1, VFromD v2, - VFromD v3) { - return Vec4{v0, v1, v2, v3}; -} - -template -HWY_API VFromD Get2(Vec2 tuple) { - static_assert(kIndex < 2, "Tuple index out of bounds"); - return kIndex == 0 ? tuple.v0 : tuple.v1; -} - -template -HWY_API VFromD Get3(Vec3 tuple) { - static_assert(kIndex < 3, "Tuple index out of bounds"); - return kIndex == 0 ? tuple.v0 : kIndex == 1 ? tuple.v1 : tuple.v2; -} - -template -HWY_API VFromD Get4(Vec4 tuple) { - static_assert(kIndex < 4, "Tuple index out of bounds"); - return kIndex == 0 ? tuple.v0 - : kIndex == 1 ? tuple.v1 - : kIndex == 2 ? 
tuple.v2 - : tuple.v3; -} - -template -HWY_API Vec2 Set2(Vec2 tuple, VFromD val) { - static_assert(kIndex < 2, "Tuple index out of bounds"); - if (kIndex == 0) { - tuple.v0 = val; - } else { - tuple.v1 = val; - } - return tuple; -} - -template -HWY_API Vec3 Set3(Vec3 tuple, VFromD val) { - static_assert(kIndex < 3, "Tuple index out of bounds"); - if (kIndex == 0) { - tuple.v0 = val; - } else if (kIndex == 1) { - tuple.v1 = val; - } else { - tuple.v2 = val; - } - return tuple; -} - -template -HWY_API Vec4 Set4(Vec4 tuple, VFromD val) { - static_assert(kIndex < 4, "Tuple index out of bounds"); - if (kIndex == 0) { - tuple.v0 = val; - } else if (kIndex == 1) { - tuple.v1 = val; - } else if (kIndex == 2) { - tuple.v2 = val; - } else { - tuple.v3 = val; - } - return tuple; -} \ No newline at end of file diff --git a/r/src/vendor/highway/hwy/ops/wasm_128-inl.h b/r/src/vendor/highway/hwy/ops/wasm_128-inl.h index 824f90a1..39471d52 100644 --- a/r/src/vendor/highway/hwy/ops/wasm_128-inl.h +++ b/r/src/vendor/highway/hwy/ops/wasm_128-inl.h @@ -92,6 +92,9 @@ class Vec128 { HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } @@ -151,9 +154,6 @@ HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { template using VFromD = decltype(Zero(D())); -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ BitCast namespace detail { @@ -213,25 +213,29 @@ template HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_i8x16_splat(static_cast(t))}; } -template +template HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_i16x8_splat(static_cast(t))}; } -template +template HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_i32x4_splat(static_cast(t))}; } -template +template HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_i64x2_splat(static_cast(t))}; } +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i16x8_splat(BitCastScalar(t))}; +} template -HWY_API VFromD Set(D /* tag */, const float t) { +HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_f32x4_splat(t)}; } template -HWY_API VFromD Set(D /* tag */, const double t) { +HWY_API VFromD Set(D /* tag */, TFromD t) { return VFromD{wasm_f64x2_splat(t)}; } @@ -251,12 +255,99 @@ template , typename T2> HWY_API VFromD Iota(D d, const T2 first) { HWY_ALIGN T lanes[MaxLanes(d)]; for (size_t i = 0; i < MaxLanes(d); ++i) { - lanes[i] = - AddWithWraparound(hwy::IsFloatTag(), static_cast(first), i); + lanes[i] = AddWithWraparound(static_cast(first), i); } return Load(d, lanes); } +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_i8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_u8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, 
t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{wasm_i16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{wasm_u16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_i32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_u32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_f32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_i64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_u64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_f64x2_make(t0, t1)}; +} + // ================================================== ARITHMETIC // ------------------------------ Addition @@ -447,6 +538,17 @@ HWY_API Vec128 AverageRound(const Vec128 a, return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; } +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + // ------------------------------ Absolute value // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. @@ -560,12 +662,16 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { } // ------------------------------ RotateRight (ShiftRight, Or) -template +template HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; - return Or(ShiftRight(v), + return Or(BitCast(d, ShiftRight(BitCast(du, v))), ShiftLeft(v)); } @@ -823,7 +929,25 @@ HWY_API Vec128 operator*(const Vec128 a, return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; } -// Returns the upper 16 bits of a * b in each lane. +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u16x8_extmul_low_u8x16(a.raw, b.raw); + const auto h = wasm_u16x8_extmul_high_u8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? 
+ return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i16x8_extmul_low_i8x16(a.raw, b.raw); + const auto h = wasm_i16x8_extmul_high_i8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} template HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { @@ -842,6 +966,22 @@ HWY_API Vec128 MulHigh(const Vec128 a, return Vec128{ wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; } +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u64x2_extmul_low_u32x4(a.raw, b.raw); + const auto h = wasm_u64x2_extmul_high_u32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i64x2_extmul_low_i32x4(a.raw, b.raw); + const auto h = wasm_i64x2_extmul_high_i32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} template HWY_API Vec128 MulFixedPoint15(Vec128 a, @@ -964,9 +1104,9 @@ HWY_API Vec128 operator/(const Vec128 a, return Vec128{wasm_f64x2_div(a.raw, b.raw)}; } -template -HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { - return Set(DFromV(), T{1.0}) / v; +template )> +HWY_API V ApproximateReciprocal(const V v) { + return Set(DFromV(), 1.0f) / v; } // Integer overload defined in generic_ops-inl.h. @@ -977,25 +1117,25 @@ HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { // ------------------------------ Floating-point multiply-add variants -template +template HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, Vec128 add) { return mul * x + add; } -template +template HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, Vec128 add) { return add - mul * x; } -template +template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { return mul * x - sub; } -template +template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { return Neg(mul) * x - sub; @@ -1014,10 +1154,10 @@ HWY_API Vec128 Sqrt(const Vec128 v) { } // Approximate reciprocal square root -template -HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { +template )> +HWY_API V ApproximateReciprocalSqrt(V v) { // TODO(eustas): find cheaper a way to calculate this. - return Set(DFromV(), T{1.0}) / Sqrt(v); + return Set(DFromV(), static_cast>(1.0)) / Sqrt(v); } // ------------------------------ Floating-point rounding @@ -1071,10 +1211,10 @@ HWY_API Mask128 IsNaN(const Vec128 v) { template HWY_API Mask128 IsInf(const Vec128 v) { const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); } // Returns whether normal/subnormal/zero. 
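As an aside (not part of the patch itself), the IsInf rewrite in the hunk above leans on a bit trick that is easier to see in scalar form: doubling the unsigned bit pattern discards the sign bit, so both +inf and -inf collapse onto the all-ones exponent with a zero mantissa. A minimal sketch for IEEE-754 binary32, using hypothetical helper names:

```cpp
#include <cstdint>
#include <cstring>

// Illustrative scalar equivalent of the vector IsInf above (hypothetical
// helper, not part of Highway). 0xFF000000u is the binary32 exponent mask
// (0x7F800000) times two, i.e. what hwy::MaxExponentTimes2<float>() yields.
bool IsInfScalarF32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // bit-cast without UB
  // bits + bits shifts out the sign bit; only +/-inf land exactly on the
  // doubled exponent mask (a NaN has a nonzero mantissa, so its doubled
  // pattern is strictly larger).
  return static_cast<uint32_t>(bits + bits) == 0xFF000000u;
}
```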
@@ -1528,13 +1668,6 @@ HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, return IfThenElse(MaskFromVec(v), yes, no); } -template -HWY_API Vec128 ZeroIfNegative(Vec128 v) { - const DFromV d; - const auto zero = Zero(d); - return IfThenElse(Mask128{(v > zero).raw}, v, zero); -} - // ------------------------------ Mask logical template @@ -1815,9 +1948,7 @@ template HWY_INLINE T ExtractLane(const Vec128 v) { const int16_t lane = wasm_i16x8_extract_lane(v.raw, kLane); - T ret; - CopySameSize(&lane, &ret); // for float16_t - return ret; + return static_cast(lane); } template @@ -1826,10 +1957,7 @@ HWY_INLINE T ExtractLane(const Vec128 v) { const RebindToUnsigned du; const uint16_t bits = ExtractLane(BitCast(du, v)); - - T ret; - CopySameSize(&bits, &ret); - return ret; + return BitCastScalar(bits); } template HWY_INLINE T ExtractLane(const Vec128 v) { @@ -2038,7 +2166,7 @@ template HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { static_assert(kLane < N, "Lane index out of bounds"); return Vec128{ - wasm_i16x8_replace_lane(v.raw, kLane, static_cast(t))}; + wasm_i16x8_replace_lane(v.raw, kLane, BitCastScalar(t))}; } template @@ -3002,6 +3130,13 @@ HWY_API Vec128 InterleaveLower(Vec128 a, return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; } +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} + // Additional overload for the optional tag (all vector lengths). template HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { @@ -3061,6 +3196,19 @@ HWY_API Vec128 InterleaveUpper(Vec128 a, return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; } +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} + template HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { @@ -3710,6 +3858,50 @@ HWY_API Vec128 OddEven(const Vec128 a, return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; } +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 2, 18, 4, 20, 6, 22, + 8, 24, 10, 26, 12, 28, 14, 30)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 2, 10, 4, 12, 6, 14)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 2, 6)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 1, 17, 3, 19, 5, 21, 7, 23, + 9, 25, 11, 27, 13, 29, 15, 31)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 1, 9, 3, 11, 5, 13, 7, 15)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 1, 5, 3, 7)}; +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + // ------------------------------ OddEvenBlocks template HWY_API 
Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { @@ -3986,6 +4178,9 @@ HWY_API VFromD PromoteUpperTo(D d, V v) { return PromoteTo(d, UpperHalf(dh, v)); } +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + // ------------------------------ Demotions (full -> part w/ narrow lanes) template @@ -4035,15 +4230,6 @@ HWY_API VFromD DemoteTo(D du8, VFromD> v) { return DemoteTo(du8, BitCast(di16, Min(v, Set(du16, 0x7FFF)))); } -template -HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { - const Rebind di32; - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); - return BitCast(dbf16, DemoteTo(du16, bits_in_32)); -} - template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { return VFromD{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; @@ -4114,15 +4300,6 @@ HWY_API VFromD DemoteTo(D df32, VFromD> v) { return DemoteTo(df32, adj_f64_val); } -template >> -HWY_API VFromD ReorderDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; - const Repartition du32; - const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); -} - // Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes // above 2*N. template @@ -4469,12 +4646,6 @@ HWY_API VFromD OrderedDemote2To(D d, V a, V b) { return ReorderDemote2To(d, a, b); } -template >> -HWY_API VFromD OrderedDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; - return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); -} - // ------------------------------ ConvertTo template @@ -4644,11 +4815,19 @@ HWY_API VFromD ConvertTo(DU du, VFromD> v) { } // ------------------------------ NearestInt (Round) -template -HWY_API Vec128 NearestInt(const Vec128 v) { +template +HWY_API Vec128, N> NearestInt(const Vec128 v) { return ConvertTo(RebindToSigned>(), Round(v)); } +// ------------------------------ DemoteToNearestInt (Round) +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. 
+ return DemoteTo(di32, Round(v)); +} + // ================================================== MISC // ------------------------------ SumsOf8 (ShiftRight, Add) @@ -4675,6 +4854,31 @@ HWY_API Vec128 SumsOf8(const Vec128 v) { return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); } +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + // ------------------------------ LoadMaskBits (TestBit) namespace detail { @@ -4729,6 +4933,15 @@ HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { return detail::LoadMaskBits(d, mask_bits); } +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + // ------------------------------ Mask namespace detail { @@ -5593,59 +5806,47 @@ HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { // ------------------------------ MulEven/Odd (Load) -HWY_INLINE Vec128 MulEven(const Vec128 a, - const Vec128 b) { - alignas(16) uint64_t mul[2]; - mul[0] = - Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), - static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); - return Load(Full128(), mul); +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128(), mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128(), mul); } -HWY_INLINE Vec128 MulOdd(const Vec128 a, - const Vec128 b) { - alignas(16) uint64_t mul[2]; - mul[0] = - Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), - static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); - return Load(Full128(), mul); +// ------------------------------ I64/U64 MulHigh (GetLane) +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); } -// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::ExtractLane<1>(a), detail::ExtractLane<1>(b), &hi_1); + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +// ------------------------------ WidenMulPairwiseAdd (MulAdd, PromoteEvenTo) // Generic for all vector lengths. 
-template >> -HWY_API VFromD WidenMulPairwiseAdd(D32 df32, V16 a, V16 b) { - const Rebind du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return Mul(BitCast(df32, ae), BitCast(df32, be)) + - Mul(BitCast(df32, ao), BitCast(df32, bo)); -} - -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, - const VFromD sum0, - VFromD& sum1) { - const Rebind du32; - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 - // Using shift/and instead of Zip leads to the odd/even order that - // RearrangeToOddPlusEven prefers. - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); } // Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is @@ -5659,35 +5860,18 @@ HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { template >> HWY_API VFromD WidenMulPairwiseAdd(DU32 du32, VU16 a, VU16 b) { - const auto lo16_mask = Set(du32, 0x0000FFFFu); - - const auto a0 = And(BitCast(du32, a), lo16_mask); - const auto b0 = And(BitCast(du32, b), lo16_mask); - - const auto a1 = ShiftRight<16>(BitCast(du32, a)); - const auto b1 = ShiftRight<16>(BitCast(du32, b)); - - return MulAdd(a1, b1, a0 * b0); + return MulAdd(PromoteEvenTo(du32, a), PromoteEvenTo(du32, b), + Mul(PromoteOddTo(du32, a), PromoteOddTo(du32, b))); } -// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is -// safe. -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 d, V16 a, V16 b, +HWY_API VFromD ReorderWidenMulAccumulate(D32 d32, V16 a, V16 b, const VFromD sum0, VFromD& /*sum1*/) { - return sum0 + WidenMulPairwiseAdd(d, a, b); -} - -// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is -// safe. 
-template >> -HWY_API VFromD ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b, - const VFromD sum0, - VFromD& /*sum1*/) { - return sum0 + WidenMulPairwiseAdd(d, a, b); + return sum0 + WidenMulPairwiseAdd(d32, a, b); } // ------------------------------ RearrangeToOddPlusEven @@ -5711,120 +5895,7 @@ HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, // ------------------------------ Reductions -namespace detail { - -// N=1: no-op -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - return v; -} - -// N=2 -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v10) { - const DFromV d; - return Add(v10, Reverse2(d, v10)); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v10) { - const DFromV d; - return Min(v10, Reverse2(d, v10)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v10) { - const DFromV d; - return Max(v10, Reverse2(d, v10)); -} - -// N=4 (only 16/32-bit, else >128-bit) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Add(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Add(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Min(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Min(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Max(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Max(v03_12_12_03, v12_03_03_12); -} - -// N=8 (only 16-bit, else >128-bit) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Add(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Add(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Add(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Min(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Min(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Min(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. 
- const V v34_25_16_07 = Max(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Max(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Max(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} - -template -HWY_INLINE T ReduceSum(Vec128 v) { - return GetLane(SumOfLanes(v)); -} - -} // namespace detail - -template -HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { - return detail::SumOfLanes(v); -} -template -HWY_API TFromD ReduceSum(D /* tag */, VFromD v) { - return detail::ReduceSum(v); -} -template -HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { - return detail::MinOfLanes(v); -} -template -HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { - return detail::MaxOfLanes(v); -} +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. // ------------------------------ Lt128 diff --git a/r/src/vendor/highway/hwy/ops/wasm_256-inl.h b/r/src/vendor/highway/hwy/ops/wasm_256-inl.h index 8716cdf7..aab7105e 100644 --- a/r/src/vendor/highway/hwy/ops/wasm_256-inl.h +++ b/r/src/vendor/highway/hwy/ops/wasm_256-inl.h @@ -43,6 +43,9 @@ class Vec256 { HWY_INLINE Vec256& operator-=(const Vec256 other) { return *this = (*this - other); } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } HWY_INLINE Vec256& operator&=(const Vec256 other) { return *this = (*this & other); } @@ -122,6 +125,50 @@ HWY_API VFromD Set(D d, const T2 t) { // Undefined, Iota defined in wasm_128. +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t8, + t9, t10, t11, t12, t13, t14, t15); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1); + return ret; +} + // ================================================== ARITHMETIC template @@ -146,6 +193,13 @@ HWY_API Vec256 SumsOf8(const Vec256 v) { return ret; } +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + template HWY_API Vec256 SaturatedAdd(Vec256 a, const Vec256 b) { a.v0 = SaturatedAdd(a.v0, b.v0); @@ -160,7 +214,8 @@ HWY_API Vec256 SaturatedSub(Vec256 a, const Vec256 b) { return a; } -template +template HWY_API Vec256 AverageRound(Vec256 a, const Vec256 b) { a.v0 = AverageRound(a.v0, b.v0); a.v1 = AverageRound(a.v1, b.v1); @@ -191,12 +246,17 @@ HWY_API Vec256 ShiftRight(Vec256 v) { } // ------------------------------ RotateRight (ShiftRight, Or) -template +template HWY_API Vec256 RotateRight(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if 
(kBits == 0) return v; - return Or(ShiftRight(v), ShiftLeft(v)); + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); } // ------------------------------ Shift lanes by same variable #bits @@ -262,8 +322,9 @@ HWY_API Vec256> MulEven(Vec256 a, const Vec256 b) { ret.v1 = MulEven(a.v1, b.v1); return ret; } -HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { - Vec256 ret; +template +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; ret.v0 = MulEven(a.v0, b.v0); ret.v1 = MulEven(a.v1, b.v1); return ret; @@ -277,8 +338,9 @@ HWY_API Vec256> MulOdd(Vec256 a, const Vec256 b) { ret.v1 = MulOdd(a.v1, b.v1); return ret; } -HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { - Vec256 ret; +template +HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { + Vec256 ret; ret.v0 = MulOdd(a.v0, b.v0); ret.v1 = MulOdd(a.v1, b.v1); return ret; @@ -300,44 +362,39 @@ HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { } // ------------------------------ Floating-point division -template +// generic_ops takes care of integer T. +template HWY_API Vec256 operator/(Vec256 a, const Vec256 b) { a.v0 /= b.v0; a.v1 /= b.v1; return a; } -// Approximate reciprocal -HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { - const Vec256 one = Set(Full256(), 1.0f); - return one / v; -} - // ------------------------------ Floating-point multiply-add variants -HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, - Vec256 add) { +template +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, Vec256 add) { mul.v0 = MulAdd(mul.v0, x.v0, add.v0); mul.v1 = MulAdd(mul.v1, x.v1, add.v1); return mul; } -HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, - Vec256 add) { +template +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, Vec256 add) { mul.v0 = NegMulAdd(mul.v0, x.v0, add.v0); mul.v1 = NegMulAdd(mul.v1, x.v1, add.v1); return mul; } -HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, - Vec256 sub) { +template +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, Vec256 sub) { mul.v0 = MulSub(mul.v0, x.v0, sub.v0); mul.v1 = MulSub(mul.v1, x.v1, sub.v1); return mul; } -HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, - Vec256 sub) { +template +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, Vec256 sub) { mul.v0 = NegMulSub(mul.v0, x.v0, sub.v0); mul.v1 = NegMulSub(mul.v1, x.v1, sub.v1); return mul; @@ -352,38 +409,35 @@ HWY_API Vec256 Sqrt(Vec256 v) { return v; } -// Approximate reciprocal square root -HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { - // TODO(eustas): find cheaper a way to calculate this. 
- const Vec256 one = Set(Full256(), 1.0f); - return one / Sqrt(v); -} - // ------------------------------ Floating-point rounding // Toward nearest integer, ties to even -HWY_API Vec256 Round(Vec256 v) { +template +HWY_API Vec256 Round(Vec256 v) { v.v0 = Round(v.v0); v.v1 = Round(v.v1); return v; } // Toward zero, aka truncate -HWY_API Vec256 Trunc(Vec256 v) { +template +HWY_API Vec256 Trunc(Vec256 v) { v.v0 = Trunc(v.v0); v.v1 = Trunc(v.v1); return v; } // Toward +infinity, aka ceiling -HWY_API Vec256 Ceil(Vec256 v) { +template +HWY_API Vec256 Ceil(Vec256 v) { v.v0 = Ceil(v.v0); v.v1 = Ceil(v.v1); return v; } // Toward -infinity, aka floor -HWY_API Vec256 Floor(Vec256 v) { +template +HWY_API Vec256 Floor(Vec256 v) { v.v0 = Floor(v.v0); v.v1 = Floor(v.v1); return v; @@ -399,10 +453,10 @@ HWY_API Mask256 IsNaN(const Vec256 v) { template HWY_API Mask256 IsInf(const Vec256 v) { const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); } // Returns whether normal/subnormal/zero. @@ -630,11 +684,6 @@ HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { return v; } -template -HWY_API Vec256 ZeroIfNegative(Vec256 v) { - return IfThenZeroElse(v < Zero(DFromV()), v); -} - // ------------------------------ Mask logical template @@ -1177,6 +1226,26 @@ HWY_API Vec256 InterleaveUpper(D d, Vec256 a, Vec256 b) { return a; } +// ------------------------------ InterleaveWholeLower +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v0, b.v0); + ret.v1 = InterleaveUpper(dh, a.v0, b.v0); + return ret; +} + +// ------------------------------ InterleaveWholeUpper +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v1, b.v1); + ret.v1 = InterleaveUpper(dh, a.v1, b.v1); + return ret; +} + // ------------------------------ ZipLower/ZipUpper defined in wasm_128 // ================================================== COMBINE @@ -1293,6 +1362,24 @@ HWY_API Vec256 OddEven(Vec256 a, const Vec256 b) { return a; } +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveEven(dh, a.v0, b.v0); + a.v1 = InterleaveEven(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveOdd(dh, a.v0, b.v0); + a.v1 = InterleaveOdd(dh, a.v1, b.v1); + return a; +} + // ------------------------------ OddEvenBlocks template HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { @@ -1789,12 +1876,12 @@ HWY_API Vec128 DemoteTo(D d16, Vec256 v) { return Combine(d16, hi, lo); } -template -HWY_API Vec128 DemoteTo(D dbf16, Vec256 v) { - const Half dbf16h; - const Vec64 lo = DemoteTo(dbf16h, v.v0); - const Vec64 hi = DemoteTo(dbf16h, v.v1); - return Combine(dbf16, hi, lo); +template +HWY_API Vec128 DemoteTo(D df32, Vec256 v) { + const Half df32h; + const Vec64 lo = DemoteTo(df32h, v.v0); + const Vec64 hi = DemoteTo(df32h, v.v1); + return Combine(df32, hi, lo); } // For already range-limited input [0, 255]. 
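Stepping outside the patch for a moment: most of the wasm_256 changes above follow the same two-half emulation pattern. A 256-bit vector is stored as a pair of 128-bit halves (v0/v1), and each op forwards to the 128-bit implementation on both halves, recombining where needed. A minimal sketch of that pattern with hypothetical names (not the actual Highway types):

```cpp
// Hypothetical stand-in for the real Vec256 wrapper: a 256-bit value held as
// two 128-bit halves, with ops forwarded to each half independently.
template <class V128>
struct PairOfHalves {
  V128 v0;  // lower 128 bits
  V128 v1;  // upper 128 bits
};

// Applies a binary 128-bit op to both halves, mirroring how binary ops such
// as AverageRound, MulEven or InterleaveEven are lifted to 256 bits in the
// hunks above.
template <class V128, class Op>
PairOfHalves<V128> ForwardToHalves(PairOfHalves<V128> a, PairOfHalves<V128> b,
                                   Op op) {
  a.v0 = op(a.v0, b.v0);
  a.v1 = op(a.v1, b.v1);
  return a;
}
```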
@@ -1849,13 +1936,6 @@ HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { } // ------------------------------ ReorderDemote2To -template -HWY_API Vec256 ReorderDemote2To(DBF16 dbf16, Vec256 a, - Vec256 b) { - const RebindToUnsigned du16; - return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); -} - template ), HWY_IF_SIGNED_V(V), HWY_IF_T_SIZE_ONE_OF_D(DN, (1 << 1) | (1 << 2) | (1 << 4)), @@ -1891,8 +1971,9 @@ HWY_API Vec256 ConvertTo(DTo d, const Vec256 v) { return ret; } -HWY_API Vec256 NearestInt(const Vec256 v) { - return ConvertTo(Full256(), Round(v)); +template +HWY_API Vec256> NearestInt(const Vec256 v) { + return ConvertTo(Full256>(), Round(v)); } // ================================================== MISC @@ -1927,6 +2008,14 @@ HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { return ret; } +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + MFromD ret; + ret.m0 = ret.m1 = Dup128MaskFromMaskBits(dh, mask_bits); + return ret; +} + // ------------------------------ Mask // `p` points to at least 8 writable bytes. @@ -2302,34 +2391,7 @@ HWY_API Vec256 RearrangeToOddPlusEven(Vec256 sum0, Vec256 sum1) { return sum0; } -// ------------------------------ Reductions - -template > -HWY_API Vec256 SumOfLanes(D d, const Vec256 v) { - const Half dh; - const Vec128 lo = SumOfLanes(dh, Add(v.v0, v.v1)); - return Combine(d, lo, lo); -} - -template > -HWY_API T ReduceSum(D d, const Vec256 v) { - const Half dh; - return ReduceSum(dh, Add(v.v0, v.v1)); -} - -template > -HWY_API Vec256 MinOfLanes(D d, const Vec256 v) { - const Half dh; - const Vec128 lo = MinOfLanes(dh, Min(v.v0, v.v1)); - return Combine(d, lo, lo); -} - -template > -HWY_API Vec256 MaxOfLanes(D d, const Vec256 v) { - const Half dh; - const Vec128 lo = MaxOfLanes(dh, Max(v.v0, v.v1)); - return Combine(d, lo, lo); -} +// ------------------------------ Reductions in generic_ops // ------------------------------ Lt128 diff --git a/r/src/vendor/highway/hwy/ops/x86_128-inl.h b/r/src/vendor/highway/hwy/ops/x86_128-inl.h index 63e7da22..a863cd10 100644 --- a/r/src/vendor/highway/hwy/ops/x86_128-inl.h +++ b/r/src/vendor/highway/hwy/ops/x86_128-inl.h @@ -47,6 +47,29 @@ namespace hwy { namespace HWY_NAMESPACE { namespace detail { +// Enable generic functions for whichever of (f16, bf16) are not supported. 
+#if !HWY_HAVE_FLOAT16 +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#else +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#endif + +#undef HWY_AVX3_HAVE_F32_TO_BF16C +#if HWY_TARGET <= HWY_AVX3_ZEN4 && !HWY_COMPILER_CLANGCL && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 900) && \ + !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_AVX3_HAVE_F32_TO_BF16C 1 +#else +#define HWY_AVX3_HAVE_F32_TO_BF16C 0 +#endif + +#undef HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "v" +#else +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "x" +#endif + template struct Raw128 { using type = __m128i; @@ -90,6 +113,9 @@ class Vec128 { HWY_INLINE Vec128& operator-=(const Vec128 other) { return *this = (*this - other); } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } HWY_INLINE Vec128& operator&=(const Vec128 other) { return *this = (*this & other); } @@ -194,18 +220,12 @@ template HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; } -template -HWY_API Vec128 Zero(D /* tag */) { - return Vec128{_mm_setzero_si128()}; -} +#if HWY_HAVE_FLOAT16 template HWY_API Vec128 Zero(D /* tag */) { -#if HWY_HAVE_FLOAT16 return Vec128{_mm_setzero_ph()}; -#else - return Vec128{_mm_setzero_si128()}; -#endif } +#endif // HWY_HAVE_FLOAT16 template HWY_API Vec128 Zero(D /* tag */) { return Vec128{_mm_setzero_ps()}; @@ -214,15 +234,16 @@ template HWY_API Vec128 Zero(D /* tag */) { return Vec128{_mm_setzero_pd()}; } +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; +} // Using the existing Zero function instead of a dedicated function for // deduction avoids having to forward-declare Vec256 here. template using VFromD = decltype(Zero(D())); -// ------------------------------ Tuple (VFromD) -#include "hwy/ops/tuple-inl.h" - // ------------------------------ BitCast namespace detail { @@ -234,6 +255,25 @@ HWY_INLINE __m128i BitCastToInteger(__m128h v) { return _mm_castph_si128(v); } HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m128i BitCastToInteger(__m128bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m128bh to a __m128i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m128bh vector + // to a __m128i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m128bh to a __m128i + return reinterpret_cast<__m128i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m128bh to a __m128i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one SSE/AVX vector type to a different SSE/AVX vector type + return BitCastScalar<__m128i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + template HWY_INLINE Vec128 BitCastToByte(Vec128 v) { return Vec128{BitCastToInteger(v.raw)}; @@ -307,7 +347,7 @@ HWY_API VFromD Set(D /* tag */, double t) { } // Generic for all vector lengths. 
-template +template HWY_API VFromD Set(D df, TFromD t) { const RebindToUnsigned du; static_assert(sizeof(TFromD) == 2, "Expecting [b]f16"); @@ -328,18 +368,12 @@ HWY_API VFromD Undefined(D /* tag */) { // generate an XOR instruction. return VFromD{_mm_undefined_si128()}; } -template -HWY_API VFromD Undefined(D /* tag */) { - return VFromD{_mm_undefined_si128()}; -} +#if HWY_HAVE_FLOAT16 template HWY_API VFromD Undefined(D /* tag */) { -#if HWY_HAVE_FLOAT16 return VFromD{_mm_undefined_ph()}; -#else - return VFromD{_mm_undefined_si128()}; -#endif } +#endif // HWY_HAVE_FLOAT16 template HWY_API VFromD Undefined(D /* tag */) { return VFromD{_mm_undefined_ps()}; @@ -348,6 +382,10 @@ template HWY_API VFromD Undefined(D /* tag */) { return VFromD{_mm_undefined_pd()}; } +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_si128()}; +} HWY_DIAGNOSTICS(pop) @@ -359,7 +397,11 @@ HWY_API T GetLane(const Vec128 v) { } template HWY_API T GetLane(const Vec128 v) { - return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFFFF); + const DFromV d; + const RebindToUnsigned du; + const uint16_t bits = + static_cast(_mm_cvtsi128_si32(BitCast(du, v).raw) & 0xFFFF); + return BitCastScalar(bits); } template HWY_API T GetLane(const Vec128 v) { @@ -394,6 +436,210 @@ HWY_API VFromD ResizeBitCast(D d, FromV v) { return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#else +// Generic for all vector lengths if HWY_HAVE_FLOAT16 is not true +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD 
Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm_setr_ps(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + // Need to use _mm_set_epi64x as there is no _mm_setr_epi64x intrinsic + // available + return VFromD{ + _mm_set_epi64x(static_cast(t1), static_cast(t0))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm_setr_pd(t0, t1)}; +} + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<1> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<2> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<4> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<8> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<16> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]); +} + +#if HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<32> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]) && + __builtin_constant_p(v[16]) && __builtin_constant_p(v[17]) && + __builtin_constant_p(v[18]) && __builtin_constant_p(v[19]) && + __builtin_constant_p(v[20]) && __builtin_constant_p(v[21]) && + __builtin_constant_p(v[22]) && __builtin_constant_p(v[23]) && + __builtin_constant_p(v[24]) && __builtin_constant_p(v[25]) && + __builtin_constant_p(v[26]) && __builtin_constant_p(v[27]) && + __builtin_constant_p(v[28]) && __builtin_constant_p(v[29]) && + 
__builtin_constant_p(v[30]) && __builtin_constant_p(v[31]); +} +#endif + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86Vec( + hwy::SizeTag num_of_lanes_tag, V v) { + using T = TFromV; +#if HWY_HAVE_FLOAT16 && HWY_HAVE_SCALAR_F16_TYPE + using F16VecLaneT = hwy::float16_t::Native; +#else + using F16VecLaneT = uint16_t; +#endif + using RawVecLaneT = If(), F16VecLaneT, + If(), uint16_t, T>>; + + // Suppress the -Wignored-attributes warning that is emitted by + // RemoveCvRef with GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") + typedef RawVecLaneT GccRawVec + __attribute__((__vector_size__(sizeof(RemoveCvRef)))); + HWY_DIAGNOSTICS(pop) + + return IsConstantRawX86Vec(num_of_lanes_tag, + reinterpret_cast(v.raw)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86VecForF2IConv(V v) { + constexpr size_t kNumOfLanesInRawSrcVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TFromV)); + constexpr size_t kNumOfLanesInRawResultVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TTo)); + constexpr size_t kNumOfLanesToCheck = + HWY_MIN(kNumOfLanesInRawSrcVec, kNumOfLanesInRawResultVec); + + return IsConstantX86Vec(hwy::SizeTag(), v); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + // ================================================== LOGICAL // ------------------------------ And @@ -402,7 +648,8 @@ template HWY_API Vec128 And(Vec128 a, Vec128 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm_and_si128(a.raw, b.raw)}); + return BitCast(d, VFromD{ + _mm_and_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); } template HWY_API Vec128 And(Vec128 a, Vec128 b) { @@ -420,8 +667,8 @@ template HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast( - d, VFromD{_mm_andnot_si128(not_mask.raw, mask.raw)}); + return BitCast(d, VFromD{_mm_andnot_si128( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); } template HWY_API Vec128 AndNot(Vec128 not_mask, @@ -440,7 +687,8 @@ template HWY_API Vec128 Or(Vec128 a, Vec128 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm_or_si128(a.raw, b.raw)}); + return BitCast(d, VFromD{ + _mm_or_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); } template @@ -458,7 +706,8 @@ template HWY_API Vec128 Xor(Vec128 a, Vec128 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm_xor_si128(a.raw, b.raw)}); + return BitCast(d, VFromD{ + _mm_xor_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); } template @@ -476,7 +725,7 @@ HWY_API Vec128 Not(const Vec128 v) { const DFromV d; const RebindToUnsigned du; using VU = VFromD; -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const __m128i vu = BitCast(du, v).raw; return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); #else @@ -487,7 +736,7 @@ HWY_API Vec128 Not(const Vec128 v) { // ------------------------------ Xor3 template HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -502,7 +751,7 @@ HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { // ------------------------------ Or3 template HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const 
RebindToUnsigned du; using VU = VFromD; @@ -517,7 +766,7 @@ HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { // ------------------------------ OrAnd template HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -533,7 +782,7 @@ HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { template HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, Vec128 no) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -546,7 +795,7 @@ HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, } // ------------------------------ BitwiseIfThenElse -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN #ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE #undef HWY_NATIVE_BITWISE_IF_THEN_ELSE @@ -651,8 +900,9 @@ HWY_INLINE Vec128 Neg(const Vec128 v) { } // ------------------------------ Floating-point Abs -template -HWY_API Vec128 Abs(const Vec128 v) { +// Generic for all vector lengths +template )> +HWY_API V Abs(V v) { const DFromV d; const RebindToSigned di; using TI = TFromD; @@ -691,407 +941,420 @@ HWY_API V CopySignToAbs(const V abs, const V sign) { // ================================================== MASK #if HWY_TARGET <= HWY_AVX3 - -// ------------------------------ IfThenElse - -// Returns mask ? b : a. +// ------------------------------ MaskFromVec namespace detail { -// Templates for signed/unsigned integer of a particular size. template -HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, - Mask128 mask, Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; } template -HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, - Mask128 mask, Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; } template -HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, - Mask128 mask, Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; } template -HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, - Mask128 mask, Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; } } // namespace detail template -HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, - Vec128 no) { - return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); } - +// There do not seem to be native floating-point versions of these instructions. 
#if HWY_HAVE_FLOAT16 template -HWY_API Vec128 IfThenElse(Mask128 mask, - Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; } -#endif // HWY_HAVE_FLOAT16 - +#endif template -HWY_API Vec128 IfThenElse(Mask128 mask, - Vec128 yes, Vec128 no) { - return Vec128{_mm_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; } - template -HWY_API Vec128 IfThenElse(Mask128 mask, - Vec128 yes, - Vec128 no) { - return Vec128{_mm_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; } -namespace detail { +template +using MFromD = decltype(MaskFromVec(VFromD())); -template -HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, - Mask128 mask, Vec128 yes) { - return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; -} -template -HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, - Mask128 mask, Vec128 yes) { - return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; -} -template -HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, - Mask128 mask, Vec128 yes) { - return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; -} -template -HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, - Mask128 mask, Vec128 yes) { - return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; -} +// ------------------------------ MaskFalse (MFromD) -} // namespace detail +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif -template -HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { - return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +// Generic for all vector lengths +template +HWY_API MFromD MaskFalse(D /*d*/) { + return MFromD{static_cast().raw)>(0)}; } -template -HWY_API Vec128 IfThenElseZero(Mask128 mask, - Vec128 yes) { - return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; -} +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif -template -HWY_API Vec128 IfThenElseZero(Mask128 mask, - Vec128 yes) { - return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + return MaskFromVec(v); } -namespace detail { +// ------------------------------ PromoteMaskTo (MFromD) -template -HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, - Mask128 mask, Vec128 no) { - // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
- return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; -} -template -HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, - Mask128 mask, Vec128 no) { - return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; -} -template -HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, - Mask128 mask, Vec128 no) { - return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; -} -template -HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, - Mask128 mask, Vec128 no) { - return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +// AVX3 PromoteMaskTo is generic for all vector lengths +template )), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; } -} // namespace detail +// ------------------------------ DemoteMaskTo (MFromD) -template -HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { - return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); -} +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif -template -HWY_API Vec128 IfThenZeroElse(Mask128 mask, - Vec128 no) { - return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +// AVX3 DemoteMaskTo is generic for all vector lengths +template ) - 1), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; } -template -HWY_API Vec128 IfThenZeroElse(Mask128 mask, - Vec128 no) { - return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; -} +// ------------------------------ CombineMasks (MFromD) -// ------------------------------ Mask logical +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif -// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. 
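Aside (illustrative, not part of the patch): the 8- and 16-bit `IfThenZeroElse` paths above rely on a masked subtract because AVX-512 offers no masked xor at those widths; subtracting a value from itself in the selected lanes yields zero, while unselected lanes pass through. A raw-intrinsics sketch of that identity, assuming AVX-512BW+VL; `ZeroWhere` is a hypothetical name.

```cpp
#include <immintrin.h>

// Requires AVX-512BW + AVX-512VL. Lanes selected by m compute no - no == 0;
// unselected lanes pass the first argument (no) through unchanged.
__m128i ZeroWhere(__m128i no, __mmask16 m) {
  return _mm_mask_sub_epi8(no, m, no, no);
}
```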
-#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) -#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ - HWY_COMPILER_CLANG >= 800 -#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 1), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(1))); #else -#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 + const auto combined_mask = + (static_cast(hi.raw) << 1) | (lo.raw & 1); #endif -#endif // HWY_COMPILER_HAS_MASK_INTRINSICS -namespace detail { + return MFromD{static_cast().raw)>(combined_mask)}; +} -template -HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, - const Mask128 b) { +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kand_mask16(a.raw, b.raw)}; + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 2), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(3))); #else - return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; + const auto combined_mask = + (static_cast(hi.raw) << 2) | (lo.raw & 3); #endif + + return MFromD{static_cast().raw)>(combined_mask)}; } -template -HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kand_mask8(a.raw, b.raw)}; + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 4), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(15))); #else - return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; -#endif -} -template -HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kand_mask8(a.raw, b.raw)}; -#else - return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; -#endif -} -template -HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kand_mask8(a.raw, b.raw)}; -#else - return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; + const auto combined_mask = + (static_cast(hi.raw) << 4) | (lo.raw & 15u); #endif + + return MFromD{static_cast().raw)>(combined_mask)}; } -template -HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kandn_mask16(a.raw, b.raw)}; -#else - return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; -#endif -} -template -HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kandn_mask8(a.raw, b.raw)}; -#else - return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; -#endif -} -template -HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kandn_mask8(a.raw, b.raw)}; -#else - return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; -#endif -} -template -HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, - const Mask128 b) { +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kandn_mask8(a.raw, b.raw)}; + const __mmask16 
combined_mask = _mm512_kunpackb( + static_cast<__mmask16>(hi.raw), static_cast<__mmask16>(lo.raw)); #else - return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; + const auto combined_mask = + ((static_cast(hi.raw) << 8) | (lo.raw & 0xFFu)); #endif + + return MFromD{static_cast().raw)>(combined_mask)}; } -template -HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kor_mask16(a.raw, b.raw)}; +// ------------------------------ LowerHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK #else - return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#define HWY_NATIVE_LOWER_HALF_OF_MASK #endif + +// Generic for all vector lengths +template +HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumOfBitsInRawMask = sizeof(RawM) * 8; + + MFromD result_mask{static_cast(m.raw)}; + + if (kN < kNumOfBitsInRawMask) { + result_mask = + And(result_mask, MFromD{static_cast((1ULL << kN) - 1)}); + } + + return result_mask; } -template -HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kor_mask8(a.raw, b.raw)}; + +// ------------------------------ UpperHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK #else - return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#define HWY_NATIVE_UPPER_HALF_OF_MASK #endif -} -template -HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kor_mask8(a.raw, b.raw)}; + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 1); #else - return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; + const auto shifted_mask = static_cast(m.raw) >> 1; #endif + + return MFromD{static_cast().raw)>(shifted_mask)}; } -template -HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kor_mask8(a.raw, b.raw)}; + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 2); #else - return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; + const auto shifted_mask = static_cast(m.raw) >> 2; #endif + + return MFromD{static_cast().raw)>(shifted_mask)}; } -template -HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, - const Mask128 b) { +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxor_mask16(a.raw, b.raw)}; + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 4); #else - return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; + const auto shifted_mask = static_cast(m.raw) >> 4; #endif + + return MFromD{static_cast().raw)>(shifted_mask)}; } -template -HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxor_mask8(a.raw, b.raw)}; + const auto shifted_mask = _kshiftri_mask16(static_cast<__mmask16>(m.raw), 8); #else - return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; + const auto shifted_mask = static_cast(m.raw) >> 8; #endif + + return 
MFromD{static_cast().raw)>(shifted_mask)}; } -template -HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxor_mask8(a.raw, b.raw)}; + +// ------------------------------ OrderedDemote2MasksTo (MFromD, CombineMasks) + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO #else - return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO #endif + +// Generic for all vector lengths +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + using MH = MFromD>; + using RawMH = decltype(MH().raw); + + return CombineMasks(d_to, MH{static_cast(b.raw)}, + MH{static_cast(a.raw)}); } -template -HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, - const Mask128 b) { -#if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxor_mask8(a.raw, b.raw)}; + +// ------------------------------ Slide mask up/down +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK #else - return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#define HWY_NATIVE_SLIDE_MASK #endif -} -template -HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, - const Mask128 a, - const Mask128 b) { +template +HWY_API MFromD SlideMask1Up(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxnor_mask16(a.raw, b.raw)}; + MFromD result_mask{ + static_cast(_kshiftli_mask8(static_cast<__mmask8>(m.raw), 1))}; + + if (kN < 8) { + result_mask = + And(result_mask, MFromD{static_cast(kValidLanesMask)}); + } #else - return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; + MFromD result_mask{ + static_cast((static_cast(m.raw) << 1) & kValidLanesMask)}; #endif + + return result_mask; } -template -HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, - const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{_kxnor_mask8(a.raw, b.raw)}; + return MFromD{ + static_cast(_kshiftli_mask16(static_cast<__mmask16>(m.raw), 1))}; #else - return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; + return MFromD{static_cast(static_cast(m.raw) << 1)}; #endif } -template -HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, - const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD SlideMask1Down(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + #if HWY_COMPILER_HAS_MASK_INTRINSICS - return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; + if (kN < 8) { + m = And(m, MFromD{static_cast(kValidLanesMask)}); + } + + return MFromD{ + static_cast(_kshiftri_mask8(static_cast<__mmask8>(m.raw), 1))}; #else - return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; + return MFromD{ + static_cast((static_cast(m.raw) & kValidLanesMask) >> 1)}; #endif } -template -HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, - const Mask128 a, - const Mask128 b) { + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); #if HWY_COMPILER_HAS_MASK_INTRINSICS - return 
Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; + return MFromD{ + static_cast(_kshiftri_mask16(static_cast<__mmask16>(m.raw), 1))}; #else - return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; + return MFromD{ + static_cast((static_cast(m.raw) & 0xFFFFu) >> 1)}; #endif } -} // namespace detail +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskUpLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? (1ULL << kN) : 0ULL) - 1ULL); -template -HWY_API Mask128 And(const Mask128 a, Mask128 b) { - return detail::And(hwy::SizeTag(), a, b); + return MFromD{static_cast( + (static_cast(m.raw) << (amt & 63)) & kValidLanesMask)}; } -template -HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - return detail::AndNot(hwy::SizeTag(), a, b); -} +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskDownLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? (1ULL << kN) : 0ULL) - 1ULL); -template -HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - return detail::Or(hwy::SizeTag(), a, b); + return MFromD{static_cast( + (static_cast(m.raw) & kValidLanesMask) >> (amt & 63))}; } -template -HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - return detail::Xor(hwy::SizeTag(), a, b); -} +// ------------------------------ VecFromMask -template -HWY_API Mask128 Not(const Mask128 m) { - // Flip only the valid bits. - // TODO(janwas): use _knot intrinsics if N >= 8. - return Xor(m, Mask128::FromBits((1ull << N) - 1)); +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; } -template -HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { - return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; } -#else // AVX2 or below +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} -// ------------------------------ Mask +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} -// Mask and Vec are the same (true = FF..FF). -template -HWY_API Mask128 MaskFromVec(const Vec128 v) { - return Mask128{v.raw}; +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ph(_mm_movm_epi16(v.raw))}; } +#endif // HWY_HAVE_FLOAT16 -template -using MFromD = decltype(MaskFromVec(VFromD())); +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{v.raw}; +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; } // Generic for all vector lengths. @@ -1100,1573 +1363,1997 @@ HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { return VecFromMask(v); } -#if HWY_TARGET >= HWY_SSSE3 +// ------------------------------ RebindMask (MaskFromVec) -// mask ? 
yes : no -template -HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, - Vec128 no) { - const auto vmask = VecFromMask(DFromV(), mask); - return Or(And(vmask, yes), AndNot(vmask, no)); +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; } -#else // HWY_TARGET < HWY_SSSE3 +// ------------------------------ IfThenElse + +namespace detail { -// mask ? yes : no template -HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, - Vec128 no) { - return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; } -template -HWY_API Vec128 IfThenElse(Mask128 mask, - Vec128 yes, Vec128 no) { - return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; } -template -HWY_API Vec128 IfThenElse(Mask128 mask, - Vec128 yes, - Vec128 no) { - return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; } -#endif // HWY_TARGET >= HWY_SSSE3 +} // namespace detail -// mask ? yes : 0 -template -HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { - return yes & VecFromMask(DFromV(), mask); +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } -// mask ? 0 : no -template -HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { - return AndNot(VecFromMask(DFromV(), mask), no); +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_ph(mask.raw, no.raw, yes.raw)}; } +#endif // HWY_HAVE_FLOAT16 -// ------------------------------ Mask logical +// Generic for all vector lengths. 
+template , HWY_X86_IF_EMULATED_D(D)> +HWY_API V IfThenElse(MFromD mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} -template -HWY_API Mask128 Not(const Mask128 m) { - const Simd d; - return MaskFromVec(Not(VecFromMask(d, m))); +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_blend_ps(mask.raw, no.raw, yes.raw)}; } -template -HWY_API Mask128 And(const Mask128 a, Mask128 b) { - const Simd d; - return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_pd(mask.raw, no.raw, yes.raw)}; } +namespace detail { + template -HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { - const Simd d; - return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; } - template -HWY_API Mask128 Or(const Mask128 a, Mask128 b) { - const Simd d; - return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; } - template -HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { - const Simd d; - return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; } - template -HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { - const Simd d; - return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; } -#endif // HWY_TARGET <= HWY_AVX3 - -// ------------------------------ ShiftLeft +} // namespace detail -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi16(v.raw, kBits)}; +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi32(v.raw, kBits)}; +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; } -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi64(v.raw, kBits)}; +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; } -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi16(v.raw, kBits)}; -} -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi32(v.raw, kBits)}; -} -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - return Vec128{_mm_slli_epi64(v.raw, kBits)}; +// Generic for all vector lengths. 
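Aside (illustrative, not part of the patch): on AVX3 the blends above take a compact mask register rather than a full vector, but callers never see that difference. A short sketch using the public ops, assuming static dispatch; `Relu` is a hypothetical name and remainder handling is omitted.

```cpp
#include <cstddef>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Zero out non-positive lanes; IfThenElseZero uses the masked moves above on
// AVX3 and a plain And with the expanded mask on older targets.
void Relu(const float* HWY_RESTRICT in, float* HWY_RESTRICT out, size_t n) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t i = 0;
  for (; i + N <= n; i += N) {
    const auto v = hn::LoadU(d, in + i);
    hn::StoreU(hn::IfThenElseZero(hn::Gt(v, hn::Zero(d)), v), d, out + i);
  }
  // Remainder lanes omitted for brevity.
}
```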
+template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenElseZero(MFromD mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); } -#if HWY_TARGET <= HWY_AVX3_DL - namespace detail { -template -HWY_API Vec128 GaloisAffine( - Vec128 v, VFromD>> matrix) { - return Vec128{_mm_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; -} -} // namespace detail - -#else // HWY_TARGET > HWY_AVX3_DL -template -HWY_API Vec128 ShiftLeft(const Vec128 v) { - const DFromV d8; - // Use raw instead of BitCast to support N=1. - const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; - return kBits == 1 - ? (v + v) - : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); -} - -#endif // HWY_TARGET > HWY_AVX3_DL - -// ------------------------------ ShiftRight - -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{_mm_srli_epi16(v.raw, kBits)}; -} -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{_mm_srli_epi32(v.raw, kBits)}; +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{_mm_srli_epi64(v.raw, kBits)}; +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; } - -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{_mm_srai_epi16(v.raw, kBits)}; +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - return Vec128{_mm_srai_epi32(v.raw, kBits)}; +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; } -#if HWY_TARGET > HWY_AVX3_DL +} // namespace detail -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - const DFromV d8; - // Use raw instead of BitCast to support N=1. - const Vec128 shifted{ - ShiftRight(Vec128{v.raw}).raw}; - return shifted & Set(d8, 0xFF >> kBits); +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } -template -HWY_API Vec128 ShiftRight(const Vec128 v) { - const DFromV di; - const RebindToUnsigned du; - const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); - const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); - return (shifted ^ shifted_sign) - shifted_sign; +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; } -#endif // HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} -// i64 is implemented after BroadcastSignBit. +// Generic for all vector lengths. 
+template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenZeroElse(MFromD mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} -// ================================================== MEMORY (1) +// ------------------------------ Mask logical -// Clang static analysis claims the memory immediately after a partial vector -// store is uninitialized, and also flags the input to partial loads (at least -// for loadl_pd) as "garbage". This is a false alarm because msan does not -// raise errors. We work around this by using CopyBytes instead of intrinsics, -// but only for the analyzer to avoid potentially bad code generation. -// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. -#ifndef HWY_SAFE_PARTIAL_LOAD_STORE -#if defined(__clang_analyzer__) || \ - (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) -#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. +#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 #else -#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 #endif -#endif // HWY_SAFE_PARTIAL_LOAD_STORE +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS -// ------------------------------ Load +namespace detail { -template -HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { - return VFromD{_mm_load_si128(reinterpret_cast(aligned))}; -} -// Generic for all vector lengths greater than or equal to 16 bytes. -template -HWY_API VFromD Load(D d, const bfloat16_t* HWY_RESTRICT aligned) { - const RebindToUnsigned du; - return BitCast(d, Load(du, reinterpret_cast(aligned))); +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif } -template -HWY_API Vec128 Load(D d, const float16_t* HWY_RESTRICT aligned) { -#if HWY_HAVE_FLOAT16 - return Vec128{_mm_load_ph(aligned)}; +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; #else - const RebindToUnsigned du; - return BitCast(d, Load(du, reinterpret_cast(aligned))); -#endif // HWY_HAVE_FLOAT16 + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif } -template -HWY_API Vec128 Load(D /* tag */, const float* HWY_RESTRICT aligned) { - return Vec128{_mm_load_ps(aligned)}; -} -template -HWY_API Vec128 Load(D /* tag */, const double* HWY_RESTRICT aligned) { - return Vec128{_mm_load_pd(aligned)}; -} - -template -HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_loadu_si128(reinterpret_cast(p))}; -} -// Generic for all vector lengths greater than or equal to 16 bytes. 
-template -HWY_API VFromD LoadU(D d, const bfloat16_t* HWY_RESTRICT p) { - const RebindToUnsigned du; - return BitCast(d, LoadU(du, reinterpret_cast(p))); -} -template -HWY_API Vec128 LoadU(D d, const float16_t* HWY_RESTRICT p) { -#if HWY_HAVE_FLOAT16 - (void)d; - return Vec128{_mm_loadu_ph(p)}; +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; #else - const RebindToUnsigned du; - return BitCast(d, LoadU(du, reinterpret_cast(p))); -#endif // HWY_HAVE_FLOAT16 -} -template -HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT p) { - return Vec128{_mm_loadu_ps(p)}; -} -template -HWY_API Vec128 LoadU(D /* tag */, const double* HWY_RESTRICT p) { - return Vec128{_mm_loadu_pd(p)}; + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif } - -template -HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { - const RebindToUnsigned du; // for float16_t -#if HWY_SAFE_PARTIAL_LOAD_STORE - __m128i v = _mm_setzero_si128(); - CopyBytes<8>(p, &v); // not same size +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; #else - const __m128i v = _mm_loadl_epi64(reinterpret_cast(p)); + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; #endif - return BitCast(d, VFromD{v}); } -template -HWY_API Vec64 Load(D /* tag */, const float* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - __m128 v = _mm_setzero_ps(); - CopyBytes<8>(p, &v); // not same size - return Vec64{v}; +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; #else - const __m128 hi = _mm_setzero_ps(); - return Vec64{_mm_loadl_pi(hi, reinterpret_cast(p))}; + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; #endif } - -template -HWY_API Vec64 Load(D /* tag */, const double* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - __m128d v = _mm_setzero_pd(); - CopyBytes<8>(p, &v); // not same size - return Vec64{v}; +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; #else - return Vec64{_mm_load_sd(p)}; + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } - -template -HWY_API Vec32 Load(D /* tag */, const float* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - __m128 v = _mm_setzero_ps(); - CopyBytes<4>(p, &v); // not same size - return Vec32{v}; +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; #else - return Vec32{_mm_load_ss(p)}; + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } - -// Any <= 32 bit except -template -HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { - const RebindToUnsigned du; // for float16_t - // Clang ArgumentPromotionPass seems to break this code. We can unpoison - // before SetTableIndices -> LoadU -> Load and the memory is poisoned again. 
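Aside (illustrative, not part of the patch): the mask-logic helpers above operate directly on `__mmask` registers when the `_k*` intrinsics are available and fall back to plain integer bit operations otherwise. Caller-side, masks combine the same way on every target; a sketch assuming the public API, with `CountInRange` as a hypothetical name and remainder handling omitted.

```cpp
#include <cstddef>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Count lanes with lo <= p[i] < hi. And() on masks uses _kand_mask* when
// available; the caller-visible behavior is the same either way.
size_t CountInRange(const float* HWY_RESTRICT p, size_t n, float lo, float hi) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto vlo = hn::Set(d, lo);
  const auto vhi = hn::Set(d, hi);
  size_t count = 0;
  size_t i = 0;
  for (; i + N <= n; i += N) {
    const auto v = hn::LoadU(d, p + i);
    count += hn::CountTrue(d, hn::And(hn::Ge(v, vlo), hn::Lt(v, vhi)));
  }
  return count;
}
```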
- detail::MaybeUnpoison(p, Lanes(d)); - -#if HWY_SAFE_PARTIAL_LOAD_STORE - __m128i v = Zero(Full128>()).raw; - CopyBytes(p, &v); // not same size as VFromD +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; #else - int32_t bits = 0; - CopyBytes(p, &bits); // not same size as VFromD - const __m128i v = _mm_cvtsi32_si128(bits); + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; #endif - return BitCast(d, VFromD{v}); -} - -// For < 128 bit, LoadU == Load. -template -HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { - return Load(d, p); -} - -// 128-bit SIMD => nothing to duplicate, same as an unaligned load. -template -HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { - return LoadU(d, p); } -// ------------------------------ Store - -template -HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { - _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); -} -// Generic for all vector lengths greater than or equal to 16 bytes. -template -HWY_API void Store(VFromD v, D d, bfloat16_t* HWY_RESTRICT aligned) { - const RebindToUnsigned du; - Store(BitCast(du, v), du, reinterpret_cast(aligned)); +template +HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif } -template -HWY_API void Store(Vec128 v, D d, float16_t* HWY_RESTRICT aligned) { -#if HWY_HAVE_FLOAT16 - (void)d; - _mm_store_ph(aligned, v.raw); +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; #else - const RebindToUnsigned du; - Store(BitCast(du, v), du, reinterpret_cast(aligned)); -#endif // HWY_HAVE_FLOAT16 + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif } -template -HWY_API void Store(Vec128 v, D /* tag */, float* HWY_RESTRICT aligned) { - _mm_store_ps(aligned, v.raw); +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif } -template -HWY_API void Store(Vec128 v, D /* tag */, - double* HWY_RESTRICT aligned) { - _mm_store_pd(aligned, v.raw); +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif } -template -HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { - _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); -} -// Generic for all vector lengths greater than or equal to 16 bytes. 
-template -HWY_API void StoreU(VFromD v, D d, bfloat16_t* HWY_RESTRICT p) { - const RebindToUnsigned du; - StoreU(BitCast(du, v), du, reinterpret_cast(p)); -} -template -HWY_API void StoreU(Vec128 v, D d, float16_t* HWY_RESTRICT p) { -#if HWY_HAVE_FLOAT16 - (void)d; - _mm_storeu_ph(p, v.raw); +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; #else - const RebindToUnsigned du; - StoreU(BitCast(du, v), du, reinterpret_cast(p)); -#endif // HWY_HAVE_FLOAT16 -} -template -HWY_API void StoreU(Vec128 v, D /* tag */, float* HWY_RESTRICT p) { - _mm_storeu_ps(p, v.raw); -} -template -HWY_API void StoreU(Vec128 v, D /* tag */, double* HWY_RESTRICT p) { - _mm_storeu_pd(p, v.raw); + return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif } - -template -HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - (void)d; - CopyBytes<8>(&v, p); // not same size +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; #else - const RebindToUnsigned du; // for float16_t - _mm_storel_epi64(reinterpret_cast<__m128i*>(p), BitCast(du, v).raw); + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } -template -HWY_API void Store(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - CopyBytes<8>(&v, p); // not same size +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; #else - _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } -template -HWY_API void Store(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - CopyBytes<8>(&v, p); // not same size +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; #else - _mm_storel_pd(p, v.raw); + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } -// Any <= 32 bit except -template -HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { - CopyBytes(&v, p); // not same size -} -template -HWY_API void Store(Vec32 v, D /* tag */, float* HWY_RESTRICT p) { -#if HWY_SAFE_PARTIAL_LOAD_STORE - CopyBytes<4>(&v, p); // not same size +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; #else - _mm_store_ss(p, v.raw); + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; #endif } - -// For < 128 bit, StoreU == Store. 
-template -HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { - Store(v, d, p); -} - -// ================================================== SWIZZLE (1) - -// ------------------------------ TableLookupBytes -template -HWY_API Vec128 TableLookupBytes(const Vec128 bytes, - const Vec128 from) { -#if HWY_TARGET == HWY_SSE2 -#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) - typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(16))); - return Vec128{reinterpret_cast::type>( - __builtin_shuffle(reinterpret_cast(bytes.raw), - reinterpret_cast(from.raw)))}; +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; #else - const DFromV d; - const Repartition du8; - const Full128 du8_full; - - const DFromV d_bytes; - const Repartition du8_bytes; - - alignas(16) uint8_t result_bytes[16]; - alignas(16) uint8_t u8_bytes[16]; - alignas(16) uint8_t from_bytes[16]; - - Store(Vec128{BitCast(du8_bytes, bytes).raw}, du8_full, u8_bytes); - Store(Vec128{BitCast(du8, from).raw}, du8_full, from_bytes); - - for (int i = 0; i < 16; i++) { - result_bytes[i] = u8_bytes[from_bytes[i] & 15]; - } - - return BitCast(d, VFromD{Load(du8_full, result_bytes).raw}); -#endif -#else // SSSE3 or newer - return Vec128{_mm_shuffle_epi8(bytes.raw, from.raw)}; + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; #endif } - -// ------------------------------ TableLookupBytesOr0 -// For all vector widths; x86 anyway zeroes if >= 0x80 on SSSE3/SSE4/AVX2/AVX3 -template -HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { -#if HWY_TARGET == HWY_SSE2 - const DFromV d; - const Repartition di8; - - const auto di8_from = BitCast(di8, from); - return BitCast(d, IfThenZeroElse(di8_from < Zero(di8), - TableLookupBytes(bytes, di8_from))); +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; #else - return TableLookupBytes(bytes, from); + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; #endif } - -// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) - -// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). -// Shuffle0321 rotates one lane to the right (the previous least-significant -// lane is now most-significant). These could also be implemented via -// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. - -// Swap 32-bit halves in 64-bit halves. template -HWY_API Vec128 Shuffle2301(const Vec128 v) { - static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); - static_assert(N == 2 || N == 4, "Does not make sense for N=1"); - return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; -} -template -HWY_API Vec128 Shuffle2301(const Vec128 v) { - static_assert(N == 2 || N == 4, "Does not make sense for N=1"); - return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; -} - -// These are used by generic_ops-inl to implement LoadInterleaved3. As with -// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output -// comes from the first argument. 
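Aside (illustrative, not part of the patch): the `TableLookupBytes` paths shown above map to `_mm_shuffle_epi8` on SSSE3 and newer, with a scalar fallback on SSE2, while `TableLookupBytesOr0` additionally zeroes any lane whose index has bit 7 set. A usage sketch assuming the public API; `ReverseBlockBytes` and `kRev` are hypothetical names, and the reversal is per 128-bit block.

```cpp
#include <cstdint>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Reverse the bytes within each 128-bit block via a byte shuffle.
// With TableLookupBytesOr0, indices with bit 7 set would produce zero instead.
template <class V>
V ReverseBlockBytes(V v) {
  const hn::DFromV<V> d;
  const hn::Repartition<uint8_t, decltype(d)> du8;
  alignas(16) static constexpr uint8_t kRev[16] = {
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const auto idx = hn::LoadDup128(du8, kRev);
  return hn::BitCast(d, hn::TableLookupBytes(hn::BitCast(du8, v), idx));
}
```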
-namespace detail { - -template -HWY_API Vec32 ShuffleTwo2301(const Vec32 a, const Vec32 b) { - const DFromV d; - const Twice d2; - const auto ba = Combine(d2, b, a); -#if HWY_TARGET == HWY_SSE2 - Vec32 ba_shuffled{ - _mm_shufflelo_epi16(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; - return BitCast(d, Or(ShiftLeft<8>(ba_shuffled), ShiftRight<8>(ba_shuffled))); +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; #else - alignas(16) const T kShuffle[8] = {1, 0, 7, 6}; - return Vec32{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; #endif } -template -HWY_API Vec64 ShuffleTwo2301(const Vec64 a, const Vec64 b) { - const DFromV d; - const Twice d2; - const auto ba = Combine(d2, b, a); -#if HWY_TARGET == HWY_SSE2 - Vec64 ba_shuffled{ - _mm_shuffle_epi32(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; - return Vec64{ - _mm_shufflelo_epi16(ba_shuffled.raw, _MM_SHUFFLE(2, 3, 0, 1))}; + +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask16>(_knot_mask16(m.raw))}; #else - alignas(16) const T kShuffle[8] = {0x0302, 0x0100, 0x0f0e, 0x0d0c}; - return Vec64{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; + return Mask128{static_cast<__mmask16>(~m.raw)}; #endif } -template -HWY_API Vec128 ShuffleTwo2301(const Vec128 a, const Vec128 b) { - const DFromV d; - const RebindToFloat df; - constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); - return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, - BitCast(df, b).raw, m)}); -} -template -HWY_API Vec32 ShuffleTwo1230(const Vec32 a, const Vec32 b) { - const DFromV d; -#if HWY_TARGET == HWY_SSE2 - const auto zero = Zero(d); - const Rebind di16; - const Vec32 a_shuffled{_mm_shufflelo_epi16( - _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; - const Vec32 b_shuffled{_mm_shufflelo_epi16( - _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; - const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); - return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_knot_mask8(m.raw))}; #else - const Twice d2; - const auto ba = Combine(d2, b, a); - alignas(16) const T kShuffle[8] = {0, 3, 6, 5}; - return Vec32{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; + return Mask128{static_cast<__mmask8>(~m.raw)}; #endif } -template -HWY_API Vec64 ShuffleTwo1230(const Vec64 a, const Vec64 b) { - const DFromV d; -#if HWY_TARGET == HWY_SSE2 - const Vec32 a_shuffled{ - _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(3, 0, 3, 0))}; - const Vec32 b_shuffled{ - _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(1, 2, 1, 2))}; - return Combine(d, b_shuffled, a_shuffled); -#else - const Twice d2; - const auto ba = Combine(d2, b, a); - alignas(16) const T kShuffle[8] = {0x0100, 0x0706, 0x0d0c, 0x0b0a}; - return Vec64{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; -#endif + +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N == 16: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); } -template -HWY_API Vec128 ShuffleTwo1230(const Vec128 a, const Vec128 b) { - const DFromV d; - const RebindToFloat df; - constexpr int m = 
_MM_SHUFFLE(1, 2, 3, 0); - return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, - BitCast(df, b).raw, m)}); +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N <= 8: need to zero out the upper bits of ~m as there + // are fewer than 16 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<1>(), m, Mask128::FromBits((1ull << N) - 1)); } +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N == 8: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N <= 4: need to zero out the upper bits of ~m as there + // are fewer than 8 valid bits in m -template -HWY_API Vec32 ShuffleTwo3012(const Vec32 a, const Vec32 b) { - const DFromV d; -#if HWY_TARGET == HWY_SSE2 - const auto zero = Zero(d); - const Rebind di16; - const Vec32 a_shuffled{_mm_shufflelo_epi16( - _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; - const Vec32 b_shuffled{_mm_shufflelo_epi16( - _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; - const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); - return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; -#else - const Twice d2; - const auto ba = Combine(d2, b, a); - alignas(16) const T kShuffle[8] = {2, 1, 4, 7}; - return Vec32{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; -#endif + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<2>(), m, Mask128::FromBits((1ull << N) - 1)); } -template -HWY_API Vec64 ShuffleTwo3012(const Vec64 a, const Vec64 b) { - const DFromV d; -#if HWY_TARGET == HWY_SSE2 - const Vec32 a_shuffled{ - _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(1, 2, 1, 2))}; - const Vec32 b_shuffled{ - _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(3, 0, 3, 0))}; - return Combine(d, b_shuffled, a_shuffled); -#else - const Twice d2; - const auto ba = Combine(d2, b, a); - alignas(16) const T kShuffle[8] = {0x0504, 0x0302, 0x0908, 0x0f0e}; - return Vec64{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; -#endif +template +HWY_INLINE Mask128 Not(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + // sizeof(T) == 4: need to zero out the upper bits of ~m as there are at most + // 4 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<4>(), m, Mask128::FromBits((1ull << N) - 1)); } -template -HWY_API Vec128 ShuffleTwo3012(const Vec128 a, const Vec128 b) { - const DFromV d; - const RebindToFloat df; - constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); - return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, - BitCast(df, b).raw, m)}); +template +HWY_INLINE Mask128 Not(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + // sizeof(T) == 8: need to zero out the upper bits of ~m as there are at most + // 2 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<8>(), m, Mask128::FromBits((1ull << N) - 1)); } } // namespace detail -// Swap 64-bit halves -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; -} -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); } -HWY_API Vec128 Shuffle1032(const Vec128 v) { - return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), 
a, b); } -HWY_API Vec128 Shuffle01(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, b); } -HWY_API Vec128 Shuffle01(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); } -HWY_API Vec128 Shuffle01(const Vec128 v) { - return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits + return detail::Not(hwy::SizeTag(), m); } -// Rotate right 32 bits -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); } -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; -} -HWY_API Vec128 Shuffle0321(const Vec128 v) { - return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; -} -// Rotate left 32 bits -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; -} -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; -} -HWY_API Vec128 Shuffle2103(const Vec128 v) { - return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; -} - -// Reverse -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; -} -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; -} -HWY_API Vec128 Shuffle0123(const Vec128 v) { - return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; -} - -// ================================================== COMPARE - -#if HWY_TARGET <= HWY_AVX3 - -// Comparisons set a mask bit to 1 if the condition is true, else 0. - -// ------------------------------ MaskFromVec - -namespace detail { -template -HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, - const Vec128 v) { - return Mask128{_mm_movepi8_mask(v.raw)}; -} -template -HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, - const Vec128 v) { - return Mask128{_mm_movepi16_mask(v.raw)}; -} -template -HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, - const Vec128 v) { - return Mask128{_mm_movepi32_mask(v.raw)}; -} -template -HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, - const Vec128 v) { - return Mask128{_mm_movepi64_mask(v.raw)}; -} +#else // AVX2 or below -} // namespace detail +// ------------------------------ Mask +// Mask and Vec are the same (true = FF..FF). template HWY_API Mask128 MaskFromVec(const Vec128 v) { - return detail::MaskFromVec(hwy::SizeTag(), v); -} -// There do not seem to be native floating-point versions of these instructions. -template -HWY_API Mask128 MaskFromVec(const Vec128 v) { - const RebindToSigned> di; - return Mask128{MaskFromVec(BitCast(di, v)).raw}; -} -template -HWY_API Mask128 MaskFromVec(const Vec128 v) { - const RebindToSigned> di; - return Mask128{MaskFromVec(BitCast(di, v)).raw}; + return Mask128{v.raw}; } template using MFromD = decltype(MaskFromVec(VFromD())); -// ------------------------------ VecFromMask - -template +template HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_movm_epi8(v.raw)}; + return Vec128{v.raw}; } -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_movm_epi16(v.raw)}; +// Generic for all vector lengths. 
+template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); } -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_movm_epi32(v.raw)}; -} +#if HWY_TARGET >= HWY_SSSE3 -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_movm_epi64(v.raw)}; +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); } -#if HWY_HAVE_FLOAT16 -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_castsi128_ph(_mm_movm_epi16(v.raw))}; -} -#endif // HWY_HAVE_FLOAT16 +#else // HWY_TARGET < HWY_SSSE3 -template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; } - template -HWY_API Vec128 VecFromMask(const Mask128 v) { - return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; } - -// Generic for all vector lengths. -template -HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { - return VecFromMask(v); +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; } -// ------------------------------ RebindMask (MaskFromVec) +#endif // HWY_TARGET >= HWY_SSSE3 -template -HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { - static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); - return MFromD{m.raw}; +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); } -// ------------------------------ TestBit +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} -namespace detail { +// ------------------------------ Mask logical template -HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, - const Vec128 bit) { - return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); } + template -HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, - const Vec128 bit) { - return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } + template -HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, - const Vec128 bit) { - return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } + template -HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, - const Vec128 bit) { - return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } -} // namespace detail +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} template -HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { - static_assert(!hwy::IsFloat(), "Only integer vectors supported"); - return detail::TestBit(hwy::SizeTag(), v, bit); +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); } -// ------------------------------ Equality +#endif // HWY_TARGET <= HWY_AVX3 -template -HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; -} +// ------------------------------ ShiftLeft -template -HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; } -template -HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; } -template -HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; } -#if HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; } -#endif // HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator==(Vec128 a, Vec128 b) { - return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; } - -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; } -// ------------------------------ Inequality 
+#if HWY_TARGET <= HWY_AVX3_DL -template -HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +namespace detail { +template +HWY_API Vec128 GaloisAffine( + Vec128 v, VFromD>> matrix) { + return Vec128{_mm_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; } +} // namespace detail -template -HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; -} +#else // HWY_TARGET > HWY_AVX3_DL -template -HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; -} - -template -HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { - return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; -} - -#if HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { - return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; -} +#endif // HWY_TARGET > HWY_AVX3_DL -// ------------------------------ Strict inequality +// ------------------------------ ShiftRight -// Signed/float < -template -HWY_API Mask128 operator>(Vec128 a, Vec128 b) { - return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; } -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; } -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; } -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; } -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; } -#if HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator>(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator>(Vec128 a, Vec128 b) { - return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); } -template -HWY_API Mask128 operator>(Vec128 a, Vec128 b) { - return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; } -// ------------------------------ Weak inequality +#endif // HWY_TARGET > HWY_AVX3_DL + +// i64 is implemented after BroadcastSignBit. + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. +#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{_mm_load_si128(reinterpret_cast(aligned))}; +} #if HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; +template +HWY_API Vec128 Load(D, const float16_t* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ph(aligned)}; } #endif // HWY_HAVE_FLOAT16 -template -HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { - return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; -} -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; -} - -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epi8_mask(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epi16_mask(a.raw, b.raw)}; +// Generic for all vector lengths greater than or equal to 16 bytes. 
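The signed 8-bit ShiftRight fallback above relies on a sign-extension identity: logically shift the byte, then XOR and subtract the shifted sign bit. A small scalar illustration (ours, not from the diff) makes the arithmetic explicit:

```cpp
#include <cstdint>

// Emulates an arithmetic right shift of an int8_t using only a logical shift,
// mirroring the (shifted ^ shifted_sign) - shifted_sign trick above.
int8_t ArithmeticShiftRightEmulated(int8_t x, int kBits) {
  const uint8_t shifted =
      static_cast<uint8_t>(static_cast<uint8_t>(x) >> kBits);
  const uint8_t shifted_sign = static_cast<uint8_t>(0x80u >> kBits);
  // If x was negative, its sign bit landed exactly on shifted_sign; the XOR
  // clears it and the subtraction borrows 1s into all higher bits, which is
  // precisely the missing sign extension.
  return static_cast<int8_t>((shifted ^ shifted_sign) - shifted_sign);
}

// Example: ArithmeticShiftRightEmulated(-8, 2) == -2, matching -8 >> 2 with
// arithmetic semantics; positive inputs pass through unchanged.
```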
+template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(aligned))); } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epi32_mask(a.raw, b.raw)}; +template +HWY_API Vec128 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epi64_mask(a.raw, b.raw)}; +template +HWY_API Vec128 Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epu8_mask(a.raw, b.raw)}; +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_loadu_si128(reinterpret_cast(p))}; } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epu16_mask(a.raw, b.raw)}; +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D, const float16_t* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ph(p)}; } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epu32_mask(a.raw, b.raw)}; +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); } -template -HWY_API Mask128 operator>=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_epu64_mask(a.raw, b.raw)}; +template +HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +template +HWY_API Vec128 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; } -#else // AVX2 or below - -// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
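For orientation, the Load/LoadU and Store/StoreU overloads added here are the ones most user code touches. The sketch below is ours (names and loop shape are illustrative, remainder handling is omitted); Load/Store require the pointer to be aligned to the vector size, which hwy::AllocateAligned from hwy/aligned_allocator.h provides, while LoadU/StoreU accept any alignment.

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Scales a possibly unaligned buffer in place using the unaligned forms.
void ScaleInPlace(float* HWY_RESTRICT data, size_t n, float factor) {
  const hn::ScalableTag<float> d;
  const auto vfactor = hn::Set(d, factor);
  for (size_t i = 0; i + hn::Lanes(d) <= n; i += hn::Lanes(d)) {
    hn::StoreU(hn::Mul(hn::LoadU(d, data + i), vfactor), d, data + i);
  }
  // Remainder lanes are left to the caller in this sketch.
}

// With aligned storage the aligned forms may be used instead, e.g.:
//   auto buf = hwy::AllocateAligned<float>(n);  // hwy/aligned_allocator.h
//   const auto v = hn::Load(d, buf.get());
```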
+template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size +#else + const __m128i v = _mm_loadl_epi64(reinterpret_cast(p)); +#endif + return BitCast(d, VFromD{v}); +} -template -HWY_API MFromD RebindMask(DTo dto, Mask128 m) { - static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); - const Simd d; - return MaskFromVec(BitCast(dto, VecFromMask(d, m))); +template +HWY_API Vec64 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec64{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif } -template -HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { - static_assert(!hwy::IsFloat(), "Only integer vectors supported"); - return (v & bit) == bit; +template +HWY_API Vec64 Load(D /* tag */, const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif } -// ------------------------------ Equality - -// Unsigned -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { -#if HWY_TARGET >= HWY_SSSE3 - const DFromV d64; - const RepartitionToNarrow d32; - const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); - const auto cmp64 = cmp32 & Shuffle2301(cmp32); - return MaskFromVec(BitCast(d64, cmp64)); +template +HWY_API Vec32 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec32{v}; #else - return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; + return Vec32{_mm_load_ss(p)}; #endif } -// Signed -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; -} -template -HWY_API Mask128 operator==(const Vec128 a, - const Vec128 b) { - // Same as signed ==; avoid duplicating the SSSE3 version. - const DFromV d; - RebindToUnsigned du; - return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +// Any <= 32 bit except +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + // Clang ArgumentPromotionPass seems to break this code. We can unpoison + // before SetTableIndices -> LoadU -> Load and the memory is poisoned again. 
+ detail::MaybeUnpoison(p, Lanes(d)); + +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = Zero(Full128>()).raw; + CopyBytes(p, &v); // not same size as VFromD +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size as VFromD + const __m128i v = _mm_cvtsi32_si128(bits); +#endif + return BitCast(d, VFromD{v}); } -// Float -template -HWY_API Mask128 operator==(Vec128 a, Vec128 b) { - return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +// For < 128 bit, LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); } -template -HWY_API Mask128 operator==(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); } -// ------------------------------ Inequality +// ------------------------------ Store -// This cannot have T as a template argument, otherwise it is not more -// specialized than rewritten operator== in C++20, leading to compile -// errors: https://gcc.godbolt.org/z/xsrPhPvPT. -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec128 v, D, float16_t* HWY_RESTRICT aligned) { + _mm_store_ph(aligned, v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, reinterpret_cast(aligned)); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +template +HWY_API void Store(Vec128 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +template +HWY_API void Store(Vec128 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D, float16_t* HWY_RESTRICT p) { + _mm_storeu_ph(p, v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Not(a == b); +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. 
+template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, reinterpret_cast(p)); } - -template -HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { - return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +template +HWY_API void StoreU(Vec128 v, D /* tag */, float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); } -template -HWY_API Mask128 operator!=(Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +template +HWY_API void StoreU(Vec128 v, D /* tag */, double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); } -// ------------------------------ Strict inequality - -namespace detail { - -template -HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + (void)d; + CopyBytes<8>(&v, p); // not same size +#else + const RebindToUnsigned du; // for float16_t + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), BitCast(du, v).raw); +#endif } -template -HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +template +HWY_API void Store(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif } -template -HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +template +HWY_API void Store(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif } -template -HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, - const Vec128 a, - const Vec128 b) { -#if HWY_TARGET >= HWY_SSSE3 - // See https://stackoverflow.com/questions/65166174/: - const DFromV d; - const RepartitionToNarrow d32; - const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; - const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; - // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: - // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. - const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; - // Duplicate upper to lower half. - return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +// Any <= 32 bit except +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +template +HWY_API void Store(Vec32 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size #else - return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 + _mm_store_ss(p, v.raw); #endif } -template -HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, - Vec128 b) { - const DFromV du; - const RebindToSigned di; - const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); - const auto sa = BitCast(di, Xor(a, msb)); - const auto sb = BitCast(di, Xor(b, msb)); - return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +// For < 128 bit, StoreU == Store. 
+template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); } -template -HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; -} -template -HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; -} +// ================================================== SWIZZLE (1) -} // namespace detail +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + const Repartition du8; -template -HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { - return detail::Gt(hwy::TypeTag(), a, b); -} + const DFromV d_bytes; + const Repartition du8_bytes; +#if HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(16))); + (void)d; + (void)du8; + (void)d_bytes; + (void)du8_bytes; + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(bytes.raw), + reinterpret_cast(from.raw)))}; +#else + const Full128 du8_full; -// ------------------------------ Weak inequality + alignas(16) uint8_t result_bytes[16]; + alignas(16) uint8_t u8_bytes[16]; + alignas(16) uint8_t from_bytes[16]; -namespace detail { -template -HWY_INLINE Mask128 Ge(hwy::SignedTag tag, Vec128 a, - Vec128 b) { - return Not(Gt(tag, b, a)); -} + Store(Vec128{BitCast(du8_bytes, bytes).raw}, du8_full, u8_bytes); + Store(Vec128{BitCast(du8, from).raw}, du8_full, from_bytes); -template -HWY_INLINE Mask128 Ge(hwy::UnsignedTag tag, Vec128 a, - Vec128 b) { - return Not(Gt(tag, b, a)); -} + for (int i = 0; i < 16; i++) { + result_bytes[i] = u8_bytes[from_bytes[i] & 15]; + } -template -HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; -} -template -HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, - Vec128 b) { - return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; + return BitCast(d, VFromD{Load(du8_full, result_bytes).raw}); +#endif +#else // SSSE3 or newer + return BitCast( + d, VFromD{_mm_shuffle_epi8(BitCast(du8_bytes, bytes).raw, + BitCast(du8, from).raw)}); +#endif } -} // namespace detail +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80 on SSSE3/SSE4/AVX2/AVX3 +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { +#if HWY_TARGET == HWY_SSE2 + const DFromV d; + const Repartition di8; -template -HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { - return detail::Ge(hwy::TypeTag(), a, b); + const auto di8_from = BitCast(di8, from); + return BitCast(d, IfThenZeroElse(di8_from < Zero(di8), + TableLookupBytes(bytes, di8_from))); +#else + return TableLookupBytes(bytes, from); +#endif } -#endif // HWY_TARGET <= HWY_AVX3 +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) -// ------------------------------ Reversed comparisons +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. +// Swap 32-bit halves in 64-bit halves. 
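Before the Shuffle* definitions, a brief note on the TableLookupBytes ops just added: the index vector selects bytes from the first argument, mapping to a single _mm_shuffle_epi8 on SSSE3 and newer, with the scalar/SSE2 fallbacks shown above. A usage sketch of ours (the function name and index table are illustrative assumptions):

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Reverses the byte order within one 128-bit vector via a byte shuffle.
hn::Vec128<uint8_t> ReverseBytes16(hn::Vec128<uint8_t> v) {
  const hn::Full128<uint8_t> d;
  alignas(16) static constexpr uint8_t kRev[16] = {
      15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  return hn::TableLookupBytes(v, hn::Load(d, kRev));
}
```

TableLookupBytesOr0 behaves the same except that lanes whose index has the high bit set are zeroed, matching the hardware behavior of pshufb.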
template -HWY_API Mask128 operator<(Vec128 a, Vec128 b) { - return b > a; +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; } - -template -HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { - return b >= a; +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; } -// ------------------------------ Iota (Load) - +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. namespace detail { -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_epi8( - static_cast(15), static_cast(14), static_cast(13), - static_cast(12), static_cast(11), static_cast(10), - static_cast(9), static_cast(8), static_cast(7), - static_cast(6), static_cast(5), static_cast(4), - static_cast(3), static_cast(2), static_cast(1), - static_cast(0))}; -} - -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_epi16(int16_t{7}, int16_t{6}, int16_t{5}, int16_t{4}, - int16_t{3}, int16_t{2}, int16_t{1}, - int16_t{0})}; +template +HWY_API Vec32 ShuffleTwo2301(const Vec32 a, const Vec32 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec32 ba_shuffled{ + _mm_shufflelo_epi16(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return BitCast(d, Or(ShiftLeft<8>(ba_shuffled), ShiftRight<8>(ba_shuffled))); +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 1, 0, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif } - -#if HWY_HAVE_FLOAT16 -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_ph(float16_t{7}, float16_t{6}, float16_t{5}, - float16_t{4}, float16_t{3}, float16_t{2}, - float16_t{1}, float16_t{0})}; +template +HWY_API Vec64 ShuffleTwo2301(const Vec64 a, const Vec64 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec64 ba_shuffled{ + _mm_shuffle_epi32(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Vec64{ + _mm_shufflelo_epi16(ba_shuffled.raw, _MM_SHUFFLE(2, 3, 0, 1))}; +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0302, 0x0100, 0x0f0e, 0x0d0c, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif } -#endif // HWY_HAVE_FLOAT16 - -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{ - _mm_set_epi32(int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); } -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_epi64x(int64_t{1}, int64_t{0})}; +template +HWY_API Vec32 ShuffleTwo1230(const Vec32 a, const Vec32 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), 
_MM_SHUFFLE(1, 2, 1, 2))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 0, 3, 6, 5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif } - -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_ps(3.0f, 2.0f, 1.0f, 0.0f)}; +template +HWY_API Vec64 ShuffleTwo1230(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + return Combine(d, b_shuffled, a_shuffled); +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0100, 0x0706, 0x0d0c, 0x0b0a, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif } - -template -HWY_INLINE VFromD Iota0(D /*d*/) { - return VFromD{_mm_set_pd(1.0, 0.0)}; +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); } -#if HWY_COMPILER_MSVC -template -static HWY_INLINE V MaskOutVec128Iota(V v) { - const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFF)}; - return v & mask_out_mask; +template +HWY_API Vec32 ShuffleTwo3012(const Vec32 a, const Vec32 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 2, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif } -template -static HWY_INLINE V MaskOutVec128Iota(V v) { -#if HWY_TARGET <= HWY_SSE4 - return V{_mm_blend_epi16(v.raw, _mm_setzero_si128(), 0xFE)}; +template +HWY_API Vec64 ShuffleTwo3012(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Combine(d, b_shuffled, a_shuffled); #else - const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFFFF)}; - return v & mask_out_mask; + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0504, 0x0302, 0x0908, 0x0f0e, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; #endif } -template -static HWY_INLINE V MaskOutVec128Iota(V v) { - const DFromV d; - const Repartition df; - using VF = VFromD; - return BitCast(d, VF{_mm_move_ss(_mm_setzero_ps(), BitCast(df, v).raw)}); -} -template -static HWY_INLINE V MaskOutVec128Iota(V v) { - const DFromV d; - 
const RebindToUnsigned du; - using VU = VFromD; - return BitCast(d, VU{_mm_move_epi64(BitCast(du, v).raw)}); -} -template -static HWY_INLINE V MaskOutVec128Iota(V v) { - return v; +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); } -#endif } // namespace detail -template -HWY_API VFromD Iota(D d, const T2 first) { - const auto result_iota = - detail::Iota0(d) + Set(d, static_cast>(first)); -#if HWY_COMPILER_MSVC - return detail::MaskOutVec128Iota(result_iota); -#else - return result_iota; -#endif +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; } - -// ------------------------------ FirstN (Iota, Lt) - -template , HWY_IF_V_SIZE_LE_D(D, 16)> -HWY_API M FirstN(D d, size_t num) { - constexpr size_t kN = MaxLanes(d); - // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks - // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. - num = HWY_MIN(num, kN); -#if HWY_TARGET <= HWY_AVX3 -#if HWY_ARCH_X86_64 - const uint64_t all = (1ull << kN) - 1; - return M::FromBits(_bzhi_u64(all, num)); -#else - const uint32_t all = static_cast((1ull << kN) - 1); - return M::FromBits(_bzhi_u32(all, static_cast(num))); -#endif // HWY_ARCH_X86_64 -#else // HWY_TARGET > HWY_AVX3 - const RebindToSigned di; // Signed comparisons are cheaper. - using TI = TFromD; - return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); -#endif // HWY_TARGET <= HWY_AVX3 +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; } - -// ------------------------------ InterleaveLower - -// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides -// the least-significant lane) and "b". To concatenate two half-width integers -// into one, use ZipLower/Upper instead (also works with scalar). - -template -HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { - return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { - const DFromV d; - const RebindToUnsigned du; - using VU = VFromD; // for float16_t - return BitCast( - d, VU{_mm_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { - return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { - return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; } -template -HWY_API Vec128 InterleaveLower(Vec128 a, - Vec128 b) { - return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; } - -// Generic for all vector lengths. 
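The Shuffle2301/Shuffle1032/Shuffle0321 family above encodes lane permutations in the name (lane 0 is least significant). A common pattern they enable is a horizontal sum of the four f32 lanes; the sketch below is ours and only assumes the ops defined in this section:

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Horizontal sum of a 4 x f32 vector via two shuffle+add steps.
float SumOfFour(hn::Vec128<float> v) {
  // lanes {3,2,1,0} + {1,0,3,2}: every lane now holds a pairwise sum.
  const auto pairs = hn::Add(v, hn::Shuffle1032(v));
  // Adding the rotated copy places the total in lane 0.
  return hn::GetLane(hn::Add(pairs, hn::Shuffle0321(pairs)));
}
```

In practice Highway's SumOfLanes/ReduceSum ops do this portably; the explicit form is only meant to illustrate the shuffle semantics documented above.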
-template -HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { - return InterleaveLower(a, b); +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; } -// ================================================== MEMORY (2) +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} -// ------------------------------ MaskedLoad +// ================================================== COMPARE #if HWY_TARGET <= HWY_AVX3 -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_maskz_loadu_epi8(m.raw, p)}; -} +// Comparisons set a mask bit to 1 if the condition is true, else 0. -template -HWY_API VFromD MaskedLoad(MFromD m, D d, - const TFromD* HWY_RESTRICT p) { - const RebindToUnsigned du; // for float16_t - return BitCast(d, VFromD{_mm_maskz_loadu_epi16(m.raw, p)}); -} +// ------------------------------ TestBit -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_maskz_loadu_epi32(m.raw, p)}; -} +namespace detail { -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_maskz_loadu_epi64(m.raw, p)}; +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; } - -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const float* HWY_RESTRICT p) { - return VFromD{_mm_maskz_loadu_ps(m.raw, p)}; +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; } -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const double* HWY_RESTRICT p) { - return VFromD{_mm_maskz_loadu_pd(m.raw, p)}; +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_epi8(v.raw, m.raw, p)}; +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_epi16(v.raw, m.raw, p)}; +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return 
Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_epi32(v.raw, m.raw, p)}; +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_epi64(v.raw, m.raw, p)}; +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const float* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_ps(v.raw, m.raw, p)}; +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } -template -HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, - const double* HWY_RESTRICT p) { - return VFromD{_mm_mask_loadu_pd(v.raw, m.raw, p)}; +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } -#elif HWY_TARGET == HWY_AVX2 +// ------------------------------ Inequality -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - auto p_p = reinterpret_cast(p); // NOLINT - return VFromD{_mm_maskload_epi32(p_p, m.raw)}; +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT p) { - auto p_p = reinterpret_cast(p); // NOLINT - return VFromD{_mm_maskload_epi64(p_p, m.raw)}; +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; } -template -HWY_API VFromD MaskedLoad(MFromD m, D d, const float* HWY_RESTRICT p) { - const RebindToSigned di; +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu64_mask(a.raw, b.raw)}; +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(dto, VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d64; + const RepartitionToNarrow d32; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. 
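Whichever branch of this comparison section gets compiled (AVX-512 mask registers above, or the vector masks in the AVX2/SSE branch below), the portable comparison ops return the same mask type that CountTrue and IfThenElse consume. A usage sketch of ours, with an illustrative function name and a scalar tail:

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Counts how many elements are >= bound.
size_t CountAtLeast(const int16_t* HWY_RESTRICT data, size_t n, int16_t bound) {
  const hn::ScalableTag<int16_t> d;
  const auto vbound = hn::Set(d, bound);
  size_t count = 0, i = 0;
  for (; i + hn::Lanes(d) <= n; i += hn::Lanes(d)) {
    count += hn::CountTrue(d, hn::Ge(hn::LoadU(d, data + i), vbound));
  }
  for (; i < n; ++i) count += (data[i] >= bound);  // scalar tail
  return count;
}
```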
+ const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const DFromV d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. 
+ return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { +template +HWY_INLINE Mask128 Ge(hwy::SignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::UnsignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ Iota (Load) + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi8( + static_cast(15), static_cast(14), static_cast(13), + static_cast(12), static_cast(11), static_cast(10), + static_cast(9), static_cast(8), static_cast(7), + static_cast(6), static_cast(5), static_cast(4), + static_cast(3), static_cast(2), static_cast(1), + static_cast(0))}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi16(int16_t{7}, int16_t{6}, int16_t{5}, int16_t{4}, + int16_t{3}, int16_t{2}, int16_t{1}, + int16_t{0})}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ph(float16_t{7}, float16_t{6}, float16_t{5}, + float16_t{4}, float16_t{3}, float16_t{2}, + float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm_set_epi32(int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi64x(int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ps(3.0f, 2.0f, 1.0f, 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_pd(1.0, 0.0)}; +} + +#if HWY_COMPILER_MSVC +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFF)}; + return v & mask_out_mask; +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { +#if HWY_TARGET <= HWY_SSE4 + return V{_mm_blend_epi16(v.raw, _mm_setzero_si128(), 0xFE)}; +#else + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFFFF)}; + return v & mask_out_mask; +#endif +} +template +static 
HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const Repartition df; + using VF = VFromD; + return BitCast(d, VF{_mm_move_ss(_mm_setzero_ps(), BitCast(df, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm_move_epi64(BitCast(du, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + return v; +} +#endif + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +#if HWY_COMPILER_MSVC + return detail::MaskOutVec128Iota(result_iota); +#else + return result_iota; +#endif +} + +// ------------------------------ FirstN (Iota, Lt) + +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API M FirstN(D d, size_t num) { + constexpr size_t kN = MaxLanes(d); + // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. + num = HWY_MIN(num, kN); +#if HWY_TARGET <= HWY_AVX3 +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << kN) - 1; + return M::FromBits(_bzhi_u64(all, num)); +#else + const uint32_t all = static_cast((1ull << kN) - 1); + return M::FromBits(_bzhi_u32(all, static_cast(num))); +#endif // HWY_ARCH_X86_64 +#else // HWY_TARGET > HWY_AVX3 + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +#endif // HWY_TARGET <= HWY_AVX3 +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ================================================== MEMORY (2) + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_maskz_loadu_epi16(m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_pd(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{ + _mm_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_pd(v.raw, m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, const float* HWY_RESTRICT p) { + const RebindToSigned di; return VFromD{_mm_maskload_ps(p, BitCast(di, VecFromMask(d, m)).raw)}; } @@ -3216,12 +3903,47 @@ HWY_API Vec128 operator-(const Vec128 a, return Vec128{_mm_sub_pd(a.raw, b.raw)}; } +// ------------------------------ AddSub + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_V_SIZE_GT_V( \ + V, ((hwy::IsFloat3264>()) ? 
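FirstN (defined above) and the MaskedLoad overloads introduced in this MEMORY(2) section combine into the usual remainder-handling idiom. The sketch below is ours; it also uses BlendedStore, the matching masked store defined further down in this header, and the function name is illustrative:

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// out[i] = in[i] + 1 for all i, with the tail handled by masked memory ops
// instead of a scalar loop.
void AddOne(const int32_t* HWY_RESTRICT in, int32_t* HWY_RESTRICT out,
            size_t n) {
  const hn::ScalableTag<int32_t> d;
  const size_t N = hn::Lanes(d);
  const auto one = hn::Set(d, int32_t{1});
  size_t i = 0;
  for (; i + N <= n; i += N) {
    hn::StoreU(hn::Add(hn::LoadU(d, in + i), one), d, out + i);
  }
  if (i < n) {
    const auto m = hn::FirstN(d, n - i);          // first (n - i) lanes active
    const auto v = hn::MaskedLoad(m, d, in + i);  // inactive lanes are zero
    hn::BlendedStore(hn::Add(v, one), m, d, out + i);
  }
}
```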
32 : sizeof(TFromV))) + +template +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_pd(a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_SSSE3 + // ------------------------------ SumsOf8 template HWY_API Vec128 SumsOf8(const Vec128 v) { return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; } +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8(V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition di64; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -1024 to each i64 lane of the result of the SumsOf8(v_adj) + // operation to account for the adjustment made above. + return BitCast(di64, SumsOf8(v_adj)) + Set(di64, int64_t{-1024}); +} + #ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF #undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF #else @@ -3234,6 +3956,136 @@ HWY_API Vec128 SumsOf8AbsDiff(const Vec128 a, return Vec128{_mm_sad_epu8(a.raw, b.raw)}; } +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 di64; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOf8AbsDiff(a_adj, b_adj) can simply be bitcasted to an + // i64 vector as |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true + return BitCast(di64, SumsOf8AbsDiff(a_adj, b_adj)); +} + +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE Vec128 SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, + Vec128 v) { + const DFromV d; + + // _mm_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm_maskz_dbsad_epu8 result. 
+ return Vec128{ + _mm_maskz_dbsad_epu8(static_cast<__mmask8>(0x55), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec128 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +#if HWY_TARGET <= HWY_SSE4 +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfAdjQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec128{ + _mm_mpsadbw_epu8(a.raw, b.raw, (kAOffset << 2) | kBOffset)}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfAdjQuadAbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOfAdjQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. + return BitCast(dw, SumsOfAdjQuadAbsDiff(a_adj, b_adj)); +} +#endif + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfShuffledQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec128{ + _mm_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfShuffledQuadAbsDiff(V a, + V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of + // SumsOfShuffledQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. + return BitCast( + dw, SumsOfShuffledQuadAbsDiff(a_adj, b_adj)); +} +#endif + // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. 
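To ground the SumsOf8 family above: each u64 lane receives the sum of 8 consecutive u8 lanes (psadbw against zero), so summing a byte buffer needs only one extra horizontal reduction at the end. The sketch below is ours; the function name, loop shape, and scalar tail are illustrative, and SumOfLanes is Highway's standard reduction op:

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Sums all bytes of a buffer into a u64.
uint64_t SumBytes(const uint8_t* HWY_RESTRICT bytes, size_t n) {
  const hn::ScalableTag<uint8_t> d8;
  const hn::Repartition<uint64_t, decltype(d8)> d64;
  auto total = hn::Zero(d64);
  size_t i = 0;
  for (; i + hn::Lanes(d8) <= n; i += hn::Lanes(d8)) {
    total = hn::Add(total, hn::SumsOf8(hn::LoadU(d8, bytes + i)));
  }
  uint64_t sum = hn::GetLane(hn::SumOfLanes(d64, total));
  for (; i < n; ++i) sum += bytes[i];  // scalar tail
  return sum;
}
```

The signed i8 overload added in this hunk reuses the same path after biasing the inputs by 128, which is why it subtracts 1024 (8 lanes times 128) per u64 sum afterwards.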
@@ -3262,7 +4114,7 @@ HWY_API Vec128 SaturatedAdd(const Vec128 a, return Vec128{_mm_adds_epi16(a.raw, b.raw)}; } -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN #ifdef HWY_NATIVE_I32_SATURATED_ADDSUB #undef HWY_NATIVE_I32_SATURATED_ADDSUB #else @@ -3300,7 +4152,7 @@ HWY_API Vec128 SaturatedAdd(Vec128 a, i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; return IfThenElse(overflow_mask, overflow_result, sum); } -#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN // ------------------------------ SaturatedSub @@ -3330,7 +4182,7 @@ HWY_API Vec128 SaturatedSub(const Vec128 a, return Vec128{_mm_subs_epi16(a.raw, b.raw)}; } -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN template HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { @@ -3356,7 +4208,7 @@ HWY_API Vec128 SaturatedSub(Vec128 a, i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; return IfThenElse(overflow_mask, overflow_result, diff); } -#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN // ------------------------------ AverageRound @@ -3374,6 +4226,18 @@ HWY_API Vec128 AverageRound(const Vec128 a, return Vec128{_mm_avg_epu16(a.raw, b.raw)}; } +// I8/I16 AverageRound is generic for all vector lengths +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + // ------------------------------ Integer multiplication template @@ -3387,7 +4251,7 @@ HWY_API Vec128 operator*(const Vec128 a, return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; } -// Returns the upper 16 bits of a * b in each lane. +// Returns the upper sizeof(T)*8 bits of a * b in each lane. template HWY_API Vec128 MulHigh(const Vec128 a, const Vec128 b) { @@ -3399,6 +4263,26 @@ HWY_API Vec128 MulHigh(const Vec128 a, return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; } +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + const Full128> d_full; + return ResizeBitCast( + d, Slide1Down(d_full, ResizeBitCast(d_full, MulEven(a, b)))); +} + +// I8/U8/I32/U32 MulHigh is generic for all vector lengths >= 2 lanes +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + return InterleaveOdd(d, p_even, p_odd); +} + // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. template )> @@ -3526,15 +4410,29 @@ HWY_API Vec128 operator*(const Vec128 a, // ------------------------------ RotateRight (ShiftRight, Or) -template -HWY_API Vec128 RotateRight(const Vec128 v) { - constexpr size_t kSizeInBits = sizeof(T) * 8; - static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); +// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); if (kBits == 0) return v; - // AVX3 does not support 8/16-bit. 
- return Or(ShiftRight(v), - ShiftLeft(v)); +#if HWY_TARGET <= HWY_AVX3_DL + return Vec128{_mm_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif } template @@ -3559,6 +4457,116 @@ HWY_API Vec128 RotateRight(const Vec128 v) { #endif } +// I8/I16/I32/I64 RotateRight is generic for all vector lengths +template +HWY_API V RotateRight(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, RotateRight(BitCast(du, v))); +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_shrdv_epi16(a.raw, a.raw, b.raw)}; +} + +// U16/I16 Rol is generic for all vector lengths on AVX3_DL +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + return Ror(a, BitCast(d, Neg(BitCast(di, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_rorv_epi64(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +// Generic for all vector lengths +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Ror(v, + Set(d, static_cast>(0u - static_cast(bits)))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(bits))); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +// Generic for all vector lengths +template +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast>(static_cast(bits)))); +} + +template +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(static_cast(bits)))); +} +#endif // HWY_TARGET <= HWY_AVX3 + // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) template @@ -3631,16 +4639,62 @@ HWY_API Vec128 Abs(const Vec128 v) { #endif } +#if HWY_TARGET <= HWY_AVX3 template HWY_API Vec128 Abs(const Vec128 v) { -#if HWY_TARGET <= HWY_AVX3 return Vec128{_mm_abs_epi64(v.raw)}; +} +#else +// I64 Abs is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template )> +HWY_API V Abs(V v) { + const auto zero = Zero(DFromV()); + return IfNegativeThenElse(v, zero - v, v); +} +#endif + +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, v), BitCast(du, SaturatedSub(Zero(d), v)))); +} + +// Generic for all vector 
lengths +template )> +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedSub(Zero(DFromV()), v)); +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + +#if HWY_TARGET <= HWY_SSE4 + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, abs_v), + Set(du, static_cast(LimitsMax())))); #else - const auto zero = Zero(DFromV()); - return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); + return Add(abs_v, BroadcastSignBit(abs_v)); #endif } +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + return Add(abs_v, BroadcastSignBit(abs_v)); +} + // GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512VL // srli_epi64: the count should be unsigned int. Note that this is not the same // as the Shift3264Count in x86_512-inl.h (GCC also requires int). @@ -3666,20 +4720,6 @@ HWY_API Vec128 ShiftRight(const Vec128 v) { #endif } -// ------------------------------ ZeroIfNegative (BroadcastSignBit) -template -HWY_API Vec128 ZeroIfNegative(Vec128 v) { - static_assert(IsFloat(), "Only works for float"); - const DFromV d; -#if HWY_TARGET >= HWY_SSSE3 - const RebindToSigned di; - const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); -#else - const auto mask = MaskFromVec(v); // MSB is sufficient for BLENDVPS -#endif - return IfThenElse(mask, Zero(d), v); -} - // ------------------------------ IfNegativeThenElse template HWY_API Vec128 IfNegativeThenElse(const Vec128 v, @@ -3743,6 +4783,91 @@ HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, #endif } +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +// SSE4/AVX2 IfNegativeThenElseZero/IfNegativeThenZeroElse is generic for all +// vector lengths +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#if HWY_TARGET <= HWY_SSSE3 + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero(Vec128 mask, + Vec128 v) { + return Vec128{_mm_sign_epi8(v.raw, mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi16(v.raw, mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi32(v.raw, mask.raw)}; +} + +// Generic for all vector lengths +template )> +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_TARGET <= HWY_AVX3 + // MaskedSubOr is more efficient than IfNegativeThenElse on AVX3 + const DFromV d; + return MaskedSubOr(v, 
MaskFromVec(mask), Zero(d), v); +#else + // IfNegativeThenElse is more efficient than MaskedSubOr on SSE4/AVX2 + return IfNegativeThenElse(mask, Neg(v), v); +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + // ------------------------------ ShiftLeftSame template @@ -3938,6 +5063,43 @@ HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { return Vec64{_mm_mul_sd(a.raw, b.raw)}; } +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 +#else +#define HWY_NATIVE_MUL_BY_POW2 +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ph(a.raw, b.raw)}; +} +#endif + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ps(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_pd(a.raw, b.raw)}; +} + +// MulByPow2 is generic for all vector lengths on AVX3 +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV d; + return MulByFloorPow2(v, ConvertTo(d, exp)); +} + +#endif // HWY_TARGET <= HWY_AVX3 + #if HWY_HAVE_FLOAT16 template HWY_API Vec128 operator/(const Vec128 a, @@ -4000,6 +5162,361 @@ HWY_API V AbsDiff(V a, V b) { return Abs(a - b); } +// ------------------------------ MaskedMinOr + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + 
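The MaskedMaxOr overloads for the wider lane types continue below with the same shape. As a reference for what these Masked*Or wrappers compute per lane (the op result where the mask bit is set, otherwise the no fallback, matching the generic IfThenElse fallbacks used further down for MaskedMulOr and MaskedDivOr), here is a standalone scalar sketch; it is illustration only, not Highway code, and MaskedMaxOrRef is a hypothetical name.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Per-lane model of MaskedMaxOr(no, m, a, b): max(a, b) where the mask is
// set, otherwise the fallback value 'no'.
template <typename T, std::size_t kLanes>
void MaskedMaxOrRef(const T (&no)[kLanes], const bool (&m)[kLanes],
                    const T (&a)[kLanes], const T (&b)[kLanes],
                    T (&out)[kLanes]) {
  for (std::size_t i = 0; i < kLanes; ++i) {
    out[i] = m[i] ? std::max(a[i], b[i]) : no[i];
  }
}

int main() {
  const int8_t no[4] = {0, 0, 0, 0};
  const bool m[4] = {true, false, true, false};
  const int8_t a[4] = {3, 3, -5, -5};
  const int8_t b[4] = {7, 7, -2, -2};
  int8_t out[4];
  MaskedMaxOrRef(no, m, a, b, out);
  for (int8_t lane : out) std::printf("%d ", lane);  // prints: 7 0 -2 0
  std::printf("\n");
  return 0;
}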
+template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_pd(no.raw, m.raw, 
a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +// There are no elementwise integer mask_mul. Generic for all vector lengths. +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, a * b, no); +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// Generic for all vector lengths +template +HWY_API V MaskedDivOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +// ------------------------------ MaskedModOr +// Generic for all vector lengths +template +HWY_API V MaskedModOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + // ------------------------------ Floating-point multiply-add variants #if HWY_HAVE_FLOAT16 @@ -4035,7 +5552,7 @@ HWY_API Vec128 NegMulSub(Vec128 mul, template HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, Vec128 add) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || 
defined(HWY_DISABLE_BMI2_FMA) return mul * x + add; #else return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; @@ -4044,7 +5561,7 @@ HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, template HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, Vec128 add) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return mul * x + add; #else return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; @@ -4055,7 +5572,7 @@ HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, template HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, Vec128 add) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return add - mul * x; #else return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; @@ -4064,7 +5581,7 @@ HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, template HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, Vec128 add) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return add - mul * x; #else return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; @@ -4075,7 +5592,7 @@ HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return mul * x - sub; #else return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; @@ -4084,7 +5601,7 @@ HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return mul * x - sub; #else return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; @@ -4095,7 +5612,7 @@ HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return Neg(mul) * x - sub; #else return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; @@ -4104,13 +5621,53 @@ HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { -#if HWY_TARGET >= HWY_SSE4 +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) return Neg(mul) * x - sub; #else return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; #endif } +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + HWY_IF_T_SIZE_ONE_OF_V( \ + V, (1 << 1) | ((hwy::IsFloat>()) \ + ? 
0 \ + : ((1 << 2) | (1 << 4) | (1 << 8)))) + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAddSub(Vec128 mul, + Vec128 x, + Vec128 sub_or_add) { + return Vec128{_mm_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + // ------------------------------ Floating-point square root // Full precision square root @@ -4196,7 +5753,8 @@ HWY_API Vec128 Min(Vec128 a, Vec128 b) { template HWY_API Vec128 Min(Vec128 a, Vec128 b) { #if HWY_TARGET >= HWY_SSSE3 - return detail::MinU(a, b); + return Vec128{ + _mm_sub_epi16(a.raw, _mm_subs_epu16(a.raw, b.raw))}; #else return Vec128{_mm_min_epu16(a.raw, b.raw)}; #endif @@ -4289,7 +5847,8 @@ HWY_API Vec128 Max(Vec128 a, Vec128 b) { template HWY_API Vec128 Max(Vec128 a, Vec128 b) { #if HWY_TARGET >= HWY_SSSE3 - return detail::MaxU(a, b); + return Vec128{ + _mm_add_epi16(a.raw, _mm_subs_epu16(b.raw, a.raw))}; #else return Vec128{_mm_max_epu16(a.raw, b.raw)}; #endif @@ -4508,116 +6067,120 @@ HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, namespace detail { -template -HWY_INLINE VFromD NativeGather128(D /* tag */, - const TFromD* HWY_RESTRICT base, - VI index) { - return VFromD{_mm_i32gather_epi32(reinterpret_cast(base), - index.raw, kScale)}; +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; } -template -HWY_INLINE VFromD NativeGather128(D /* tag */, - const TFromD* HWY_RESTRICT base, - VI index) { - return VFromD{_mm_i64gather_epi64( - reinterpret_cast(base), index.raw, kScale)}; +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; } -template -HWY_INLINE VFromD NativeGather128(D /* tag */, - const float* HWY_RESTRICT base, VI index) { - return VFromD{_mm_i32gather_ps(base, index.raw, kScale)}; +template +HWY_INLINE Vec128 NativeGather128(const float* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i32gather_ps(base, indices.raw, kScale)}; } -template -HWY_INLINE VFromD NativeGather128(D /* tag */, - const double* HWY_RESTRICT base, - VI index) { - return VFromD{_mm_i64gather_pd(base, index.raw, kScale)}; +template +HWY_INLINE Vec128 NativeGather128(const double* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_pd(base, indices.raw, kScale)}; } - -template -HWY_INLINE VFromD NativeMaskedGather128(MFromD m, D d, - const TFromD* HWY_RESTRICT base, - VI index) { - // For partial vectors, ensure upper mask lanes are zero to prevent faults. 
- if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{_mm_mmask_i32gather_epi32( - Zero(d).raw, m.raw, index.raw, reinterpret_cast(base), + return Vec128{_mm_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), kScale)}; #else - return VFromD{_mm_mask_i32gather_epi32( - Zero(d).raw, reinterpret_cast(base), index.raw, m.raw, - kScale)}; + return Vec128{ + _mm_mask_i32gather_epi32(no.raw, reinterpret_cast(base), + indices.raw, m.raw, kScale)}; #endif } -template -HWY_INLINE VFromD NativeMaskedGather128(MFromD m, D d, - const TFromD* HWY_RESTRICT base, - VI index) { - // For partial vectors, ensure upper mask lanes are zero to prevent faults. - if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{_mm_mmask_i64gather_epi64( - Zero(d).raw, m.raw, index.raw, - reinterpret_cast(base), kScale)}; + return Vec128{_mm_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; #else - return VFromD{_mm_mask_i64gather_epi64( - Zero(d).raw, reinterpret_cast(base), index.raw, - m.raw, kScale)}; + return Vec128{_mm_mask_i64gather_epi64( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; #endif } -template -HWY_INLINE VFromD NativeMaskedGather128(MFromD m, D d, - const float* HWY_RESTRICT base, - VI index) { - // For partial vectors, ensure upper mask lanes are zero to prevent faults. - if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const float* HWY_RESTRICT base, + Vec128 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{ - _mm_mmask_i32gather_ps(Zero(d).raw, m.raw, index.raw, base, kScale)}; + return Vec128{ + _mm_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; #else - return VFromD{ - _mm_mask_i32gather_ps(Zero(d).raw, base, index.raw, m.raw, kScale)}; + return Vec128{ + _mm_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; #endif } -template -HWY_INLINE VFromD NativeMaskedGather128(MFromD m, D d, - const double* HWY_RESTRICT base, - VI index) { - // For partial vectors, ensure upper mask lanes are zero to prevent faults. 
- if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const double* HWY_RESTRICT base, + Vec128 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{ - _mm_mmask_i64gather_pd(Zero(d).raw, m.raw, index.raw, base, kScale)}; + return Vec128{ + _mm_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; #else - return VFromD{ - _mm_mask_i64gather_pd(Zero(d).raw, base, index.raw, m.raw, kScale)}; + return Vec128{ + _mm_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; #endif } } // namespace detail -template , class VI> -HWY_API VFromD GatherOffset(D d, const T* HWY_RESTRICT base, VI offset) { - static_assert(sizeof(T) == sizeof(TFromV), "Index/lane size must match"); - return detail::NativeGather128<1>(d, base, offset); +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather128<1>(base, offsets); +} + +template > +HWY_API VFromD GatherIndex(D /*d*/, const T* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather128(base, indices); } -template , class VI> -HWY_API VFromD GatherIndex(D d, const T* HWY_RESTRICT base, VI index) { - static_assert(sizeof(T) == sizeof(TFromV), "Index/lane size must match"); - return detail::NativeGather128(d, base, index); + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> indices) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + + return detail::NativeMaskedGatherOr128(no, m, base, indices); } -template , class VI> + +// Generic for all vector lengths. +template HWY_API VFromD MaskedGatherIndex(MFromD m, D d, - const T* HWY_RESTRICT base, VI index) { - static_assert(sizeof(T) == sizeof(TFromV), "Index/lane size must match"); - return detail::NativeMaskedGather128(m, d, base, index); + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); } #endif // HWY_TARGET <= HWY_AVX2 @@ -4740,9 +6303,7 @@ HWY_INLINE T ExtractLane(const Vec128 v) { const RebindToUnsigned du; const uint16_t lane = static_cast( _mm_extract_epi16(BitCast(du, v).raw, kLane) & 0xFFFF); - T ret; - CopySameSize(&lane, &ret); // for float16_t - return ret; + return BitCastScalar(lane); } template @@ -4780,9 +6341,7 @@ HWY_INLINE float ExtractLane(const Vec128 v) { #else // Bug in the intrinsic, returns int but should be float. const int32_t bits = _mm_extract_ps(v.raw, kLane); - float ret; - CopySameSize(&bits, &ret); - return ret; + return BitCastScalar(bits); #endif } @@ -4958,8 +6517,7 @@ HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { static_assert(kLane < N, "Lane index out of bounds"); const DFromV d; const RebindToUnsigned du; - uint16_t bits; - CopySameSize(&t, &bits); // for float16_t + const uint16_t bits = BitCastScalar(t); return BitCast(d, VFromD{ _mm_insert_epi16(BitCast(du, v).raw, bits, kLane)}); } @@ -4970,8 +6528,7 @@ HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { #if HWY_TARGET >= HWY_SSSE3 return InsertLaneUsingBroadcastAndBlend(v, kLane, t); #else - MakeSigned ti; - CopySameSize(&t, &ti); // don't just cast because T might be float. 
+ const MakeSigned ti = BitCastScalar>(t); return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; #endif } @@ -4990,8 +6547,7 @@ HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { return BitCast( d, Vec128{_mm_shuffle_pd(BitCast(df, v).raw, vt.raw, 0)}); #else - MakeSigned ti; - CopySameSize(&t, &ti); // don't just cast because T might be float. + const MakeSigned ti = BitCastScalar>(t); return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; #endif } @@ -5325,414 +6881,635 @@ HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { return Indices128{vec.raw}; } -template -HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( - D d, const TI* idx) { - static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); - const Rebind di; - return IndicesFromVec(d, LoadU(di, idx)); -} +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return {_mm_permutexvar_epi16(idx.raw, v.raw)}; +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint16_t GccU16RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v.raw), + reinterpret_cast(idx.raw)))}; +#else + const Full128 d_full; + alignas(16) T src_lanes[8]; + alignas(16) uint16_t indices[8]; + alignas(16) T result_lanes[8]; + + Store(Vec128{v.raw}, d_full, src_lanes); + _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); + + for (int i = 0; i < 8; i++) { + result_lanes[i] = src_lanes[indices[i] & 7u]; + } + + return Vec128{Load(d_full, result_lanes).raw}; +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + return {_mm_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV d; + const RebindToFloat df; + const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v.raw), + reinterpret_cast(idx.raw)))}; +#else + const Full128 d_full; + alignas(16) T src_lanes[4]; + alignas(16) uint32_t indices[4]; + alignas(16) T result_lanes[4]; + + Store(Vec128{v.raw}, d_full, src_lanes); + _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); + + for (int i = 0; i < 4; i++) { + result_lanes[i] = src_lanes[indices[i] & 3u]; + } + + return Vec128{Load(d_full, result_lanes).raw}; +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else // SSSE3 or SSE4 + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; +#else // SSSE3 or SSE4 + 
const DFromV df; + const RebindToSigned di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +#endif // HWY_TARGET <= HWY_AVX2 +} +#endif // HWY_TARGET <= HWY_SSSE3 + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const RebindToSigned di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const DFromV d; + const RebindToSigned di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + return v; +} + +// 32-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return VFromD{Shuffle2301(Vec128>{v.raw}).raw}; +} + +// 64-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle01(v); +} + +// 32-bit x4: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + if (kN == 1) return v; + if (kN == 2) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 0, 1))}); + } + if (kN == 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET == HWY_SSE2 + const VU rev4{ + _mm_shufflehi_epi16(_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}; + return BitCast(d, VU{_mm_shuffle_epi32(rev4.raw, _MM_SHUFFLE(1, 0, 3, 2))}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + constexpr int kN = static_cast(MaxLanes(d)); + if (kN == 1) return v; +#if HWY_TARGET <= HWY_SSSE3 + // NOTE: Lanes with negative shuffle control mask values are set to zero. 
+ alignas(16) static constexpr int8_t kReverse[16] = { + kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, + kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16}; + const RebindToSigned di; + const VFromD idx = Load(di, kReverse); + return VFromD{_mm_shuffle_epi8(BitCast(di, v).raw, idx.raw)}; +#else + const RepartitionToWide d16; + return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); +#endif +} + +// ------------------------------ Reverse2 -template -HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { - return TableLookupBytes(v, Vec128{idx.raw}); +// Single lane: no change +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return v; } -template -HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +// Generic for all vector lengths (128-bit sufficient if SSE2). +template +HWY_API VFromD Reverse2(D d, VFromD v) { #if HWY_TARGET <= HWY_AVX3 - return {_mm_permutexvar_epi16(idx.raw, v.raw)}; + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); #elif HWY_TARGET == HWY_SSE2 -#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) - typedef uint16_t GccU16RawVectType __attribute__((__vector_size__(16))); - return Vec128{reinterpret_cast::type>( - __builtin_shuffle(reinterpret_cast(v.raw), - reinterpret_cast(idx.raw)))}; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + __m128i shuf_result = _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(2, 3, 0, 1)); + if (kN > 4) { + shuf_result = _mm_shufflehi_epi16(shuf_result, _MM_SHUFFLE(2, 3, 0, 1)); + } + return BitCast(d, VU{shuf_result}); #else - const Full128 d_full; - alignas(16) T src_lanes[8]; - alignas(16) uint16_t indices[8]; - alignas(16) T result_lanes[8]; + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0302, 0x0100, 0x0706, 0x0504, 0x0B0A, 0x0908, 0x0F0E, 0x0D0C); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} - Store(Vec128{v.raw}, d_full, src_lanes); - _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle2301(v); +} - for (int i = 0; i < 8; i++) { - result_lanes[i] = src_lanes[indices[i] & 7u]; +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + // 4x 16-bit: a single shufflelo suffices. + constexpr size_t kN = MaxLanes(d); + if (kN <= 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); } - return Vec128{Load(d_full, result_lanes).raw}; -#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#if HWY_TARGET == HWY_SSE2 + return BitCast(d, VU{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}); #else - return TableLookupBytes(v, Vec128{idx.raw}); + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); #endif } -#if HWY_HAVE_FLOAT16 -template -HWY_API Vec128 TableLookupLanes(Vec128 v, - Indices128 idx) { - return {_mm_permutexvar_ph(idx.raw, v.raw)}; +// Generic for all vector lengths. 
+template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); } -#endif // HWY_HAVE_FLOAT16 -template -HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { -#if HWY_TARGET <= HWY_AVX2 - const DFromV d; - const RebindToFloat df; - const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; - return BitCast(d, perm); -#elif HWY_TARGET == HWY_SSE2 -#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) - typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); - return Vec128{reinterpret_cast::type>( - __builtin_shuffle(reinterpret_cast(v.raw), - reinterpret_cast(idx.raw)))}; +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { +#if HWY_TARGET == HWY_SSE2 + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); #else - const Full128 d_full; - alignas(16) T src_lanes[4]; - alignas(16) uint32_t indices[4]; - alignas(16) T result_lanes[4]; + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} - Store(Vec128{v.raw}, d_full, src_lanes); - _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} - for (int i = 0; i < 4; i++) { - result_lanes[i] = src_lanes[indices[i] & 3u]; - } +// ------------------------------ ReverseBits in x86_512 - return Vec128{Load(d_full, result_lanes).raw}; -#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) -#else // SSSE3 or SSE4 - return TableLookupBytes(v, Vec128{idx.raw}); +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi64(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// -------------------------- I8/U8 Broadcast (InterleaveLower, InterleaveUpper) + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + const DFromV d; + +#if HWY_TARGET == HWY_SSE2 + const Full128 d_full; + const Vec128 v_full{v.raw}; + const auto v_interleaved = (kLane < 8) + ? 
InterleaveLower(d_full, v_full, v_full) + : InterleaveUpper(d_full, v_full, v_full); + return ResizeBitCast( + d, Broadcast(BitCast(Full128(), v_interleaved))); +#else + return TableLookupBytes(v, Set(d, static_cast(kLane))); #endif } -#if HWY_TARGET <= HWY_SSSE3 -template -HWY_API Vec128 TableLookupLanes(Vec128 v, - Indices128 idx) { -#if HWY_TARGET <= HWY_AVX2 - return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; -#else // SSSE3 or SSE4 - const DFromV df; - const RebindToSigned di; - return BitCast(df, - TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); -#endif // HWY_TARGET <= HWY_AVX2 +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +// Generic for all vector lengths. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); } -#endif // HWY_TARGET <= HWY_SSSE3 -// Single lane: no change -template -HWY_API Vec128 TableLookupLanes(Vec128 v, - Indices128 /* idx */) { - return v; +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); } -template -HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { - const DFromV d; - Vec128 vidx{idx.raw}; -#if HWY_TARGET <= HWY_AVX2 - // There is no _mm_permute[x]var_epi64. - vidx += vidx; // bit1 is the decider (unusual) - const RebindToFloat df; - return BitCast( - d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +// ================================================== CONVERT (1) + +// ------------------------------ PromoteTo unsigned (TableLookupBytesOr0) +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return VFromD{_mm_unpacklo_epi8(v.raw, zero)}; #else - // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit - // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 - // to obtain an all-zero or all-one mask. - const RebindToSigned di; - const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); - const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); - return IfThenElse(mask_same, v, Shuffle01(v)); + return VFromD{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return VFromD{_mm_cvtepu16_epi32(v.raw)}; #endif } - -HWY_API Vec128 TableLookupLanes(Vec128 v, - Indices128 idx) { - Vec128 vidx{idx.raw}; -#if HWY_TARGET <= HWY_AVX2 - vidx += vidx; // bit1 is the decider (unusual) - return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; #else - // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit - // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 - // to obtain an all-zero or all-one mask. 
- const DFromV d; - const RebindToSigned di; - const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); - const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); - return IfThenElse(mask_same, v, Shuffle01(v)); + return VFromD{_mm_cvtepu32_epi64(v.raw)}; #endif } - -// ------------------------------ ReverseBlocks - -// Single block: no change -template -HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { - return v; -} - -// ------------------------------ Reverse (Shuffle0123, Shuffle2301) - -// Single lane: no change -template -HWY_API VFromD Reverse(D /* tag */, VFromD v) { - return v; +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return VFromD{_mm_unpacklo_epi16(u16, zero)}; +#else + return VFromD{_mm_cvtepu8_epi32(v.raw)}; +#endif } - -// 32-bit x2: shuffle -template -HWY_API VFromD Reverse(D /* tag */, const VFromD v) { - return VFromD{Shuffle2301(Vec128>{v.raw}).raw}; +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu8_epi64(v.raw)}; +#endif } - -// 64-bit x2: shuffle -template -HWY_API VFromD Reverse(D /* tag */, const VFromD v) { - return Shuffle01(v); +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, 1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu16_epi64(v.raw)}; +#endif } -// 32-bit x4: shuffle -template -HWY_API VFromD Reverse(D /* tag */, const VFromD v) { - return Shuffle0123(v); +// Unsigned to signed: same plus cast. 
+template ), sizeof(TFromV)), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); } -// 16-bit -template -HWY_API VFromD Reverse(D d, const VFromD v) { - const RebindToUnsigned du; - using VU = VFromD; - const VU vu = BitCast(du, v); // for float16_t - constexpr size_t kN = MaxLanes(d); - if (kN == 1) return v; - if (kN == 2) { - return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 0, 1))}); - } - if (kN == 4) { - return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); - } +// ------------------------------ PromoteTo signed (ShiftRight, ZipLower) -#if HWY_TARGET == HWY_SSE2 - const VU rev4{ - _mm_shufflehi_epi16(_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), - _MM_SHUFFLE(0, 1, 2, 3))}; - return BitCast(d, VU{_mm_shuffle_epi32(rev4.raw, _MM_SHUFFLE(1, 0, 3, 2))}); +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<8>(VFromD{_mm_unpacklo_epi8(v.raw, v.raw)}); #else - const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); + return VFromD{_mm_cvtepi8_epi16(v.raw)}; #endif } - -template -HWY_API VFromD Reverse(D d, const VFromD v) { - constexpr int kN = static_cast(MaxLanes(d)); - if (kN == 1) return v; -#if HWY_TARGET <= HWY_SSSE3 - // NOTE: Lanes with negative shuffle control mask values are set to zero. - alignas(16) static constexpr int8_t kReverse[16] = { - kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, - kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16}; - const RebindToSigned di; - const VFromD idx = Load(di, kReverse); - return VFromD{_mm_shuffle_epi8(BitCast(di, v).raw, idx.raw)}; +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<16>(VFromD{_mm_unpacklo_epi16(v.raw, v.raw)}); #else - const RepartitionToWide d16; - return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); + return VFromD{_mm_cvtepi16_epi32(v.raw)}; #endif } - -// ------------------------------ Reverse2 - -// Single lane: no change -template -HWY_API VFromD Reverse2(D /* tag */, VFromD v) { - return v; +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<32>(VFromD{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi32_epi64(v.raw)}; +#endif } - -// Generic for all vector lengths (128-bit sufficient if SSE2). 
-template -HWY_API VFromD Reverse2(D d, VFromD v) { -#if HWY_TARGET <= HWY_AVX3 - const Repartition du32; - return BitCast(d, RotateRight<16>(BitCast(du32, v))); -#elif HWY_TARGET == HWY_SSE2 - const RebindToUnsigned du; - using VU = VFromD; - const VU vu = BitCast(du, v); // for float16_t - constexpr size_t kN = MaxLanes(d); - __m128i shuf_result = _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(2, 3, 0, 1)); - if (kN > 4) { - shuf_result = _mm_shufflehi_epi16(shuf_result, _MM_SHUFFLE(2, 3, 0, 1)); - } - return BitCast(d, VU{shuf_result}); +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(VFromD{x4}); #else - const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0302, 0x0100, 0x0706, 0x0504, 0x0B0A, 0x0908, 0x0F0E, 0x0D0C}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); + return VFromD{_mm_cvtepi8_epi32(v.raw)}; #endif } - -// Generic for all vector lengths. -template -HWY_API VFromD Reverse2(D /* tag */, VFromD v) { - return Shuffle2301(v); +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x4{PromoteTo(dh_i32, v).raw}; + const VFromD s4{ + _mm_shufflelo_epi16(x4.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x4, s4); +#else + (void)d; + return VFromD{_mm_cvtepi8_epi64(v.raw)}; +#endif } - -// Generic for all vector lengths. -template -HWY_API VFromD Reverse2(D /* tag */, VFromD v) { - return Shuffle01(v); +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x2{PromoteTo(dh_i32, v).raw}; + const VFromD s2{ + _mm_shufflelo_epi16(x2.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x2, s2); +#else + (void)d; + return VFromD{_mm_cvtepi16_epi64(v.raw)}; +#endif } -// ------------------------------ Reverse4 - -template -HWY_API VFromD Reverse4(D d, VFromD v) { - const RebindToUnsigned du; - using VU = VFromD; - const VU vu = BitCast(du, v); // for float16_t - // 4x 16-bit: a single shufflelo suffices. - constexpr size_t kN = MaxLanes(d); - if (kN <= 4) { - return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); - } +// -------------------- PromoteTo float (ShiftLeft, IfNegativeThenElse) +#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) -#if HWY_TARGET == HWY_SSE2 - return BitCast(d, VU{_mm_shufflehi_epi16( - _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), - _MM_SHUFFLE(0, 1, 2, 3))}); +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C #else - const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); +#define HWY_NATIVE_F16C #endif -} -// Generic for all vector lengths. 
-template -HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { - return Shuffle0123(v); +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 VFromD PromoteTo(D /*tag*/, VFromD> v) { +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm_cvtph_ps(v.raw)}; +#endif } -template -HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { - HWY_ASSERT(0); // don't have 4 u64 lanes -} +#endif // HWY_NATIVE_F16C -// ------------------------------ Reverse8 +#if HWY_HAVE_FLOAT16 -template -HWY_API VFromD Reverse8(D d, const VFromD v) { -#if HWY_TARGET == HWY_SSE2 - const RepartitionToWide dw; - return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 #else - const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); +#define HWY_NATIVE_PROMOTE_F16_TO_F64 #endif -} -template -HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { - HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, VFromD> v) { + return VFromD{_mm_cvtph_pd(v.raw)}; } -// ------------------------------ ReverseBits in x86_512 - -// ------------------------------ InterleaveUpper (UpperHalf) +#endif // HWY_HAVE_FLOAT16 -// Full -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - return VFromD{_mm_unpackhi_epi8(a.raw, b.raw)}; -} -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - const DFromV d; - const RebindToUnsigned du; - using VU = VFromD; // for float16_t - return BitCast( - d, VU{_mm_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); -} -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - return VFromD{_mm_unpackhi_epi32(a.raw, b.raw)}; -} -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - return VFromD{_mm_unpackhi_epi64(a.raw, b.raw)}; -} -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - return VFromD{_mm_unpackhi_ps(a.raw, b.raw)}; -} -template -HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { - return VFromD{_mm_unpackhi_pd(a.raw, b.raw)}; +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } -// Partial -template -HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { - const Half d2; - return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, - VFromD{UpperHalf(d2, b).raw}); +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtps_pd(v.raw)}; } -// -------------------------- I8/U8 Broadcast (InterleaveLower, InterleaveUpper) - -template -HWY_API Vec128 Broadcast(const Vec128 v) { - static_assert(0 <= kLane && kLane < N, "Invalid lane"); - const DFromV d; - -#if HWY_TARGET == HWY_SSE2 - const Full128 d_full; - const Vec128 v_full{v.raw}; - const auto v_interleaved = (kLane < 8) - ? 
InterleaveLower(d_full, v_full, v_full) - : InterleaveUpper(d_full, v_full, v_full); - return ResizeBitCast( - d, Broadcast(BitCast(Full128(), v_interleaved))); -#else - return TableLookupBytes(v, Set(d, static_cast(kLane))); -#endif +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi32_pd(v.raw)}; } -// ------------------------------ ZipLower/ZipUpper (InterleaveLower) - -// Same as Interleave*, except that the return lanes are double-width integers; -// this is necessary because the single-lane scalar cannot return two values. -// Generic for all vector lengths. -template >> -HWY_API VFromD ZipLower(V a, V b) { - return BitCast(DW(), InterleaveLower(a, b)); -} -template , class DW = RepartitionToWide> -HWY_API VFromD ZipLower(DW dw, V a, V b) { - return BitCast(dw, InterleaveLower(D(), a, b)); +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteTo(D /*df64*/, VFromD> v) { + return VFromD{_mm_cvtepu32_pd(v.raw)}; } - -template , class DW = RepartitionToWide> -HWY_API VFromD ZipUpper(DW dw, V a, V b) { - return BitCast(dw, InterleaveUpper(D(), a, b)); +#else +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + const Rebind di32; + const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v)); + return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result, + Set(df64, 4294967296.0), + Zero(df64)); } +#endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ Per4LaneBlockShuffle namespace detail { @@ -5758,7 +7535,11 @@ template HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, hwy::SizeTag<8> /*vect_size_tag*/, V v) { - return V{_mm_shufflelo_epi16(v.raw, static_cast(kIdx3210 & 0xFF))}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, static_cast(kIdx3210 & 0xFF))}); } #if HWY_TARGET == HWY_SSE2 @@ -5766,8 +7547,12 @@ template HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); - return V{_mm_shufflehi_epi16(_mm_shufflelo_epi16(v.raw, kShuffle), kShuffle)}; + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, kShuffle), kShuffle)}); } template HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, size_t max_lanes_to_store) { - const size_t num_of_lanes_to_store = + const size_t num_lanes_to_store = HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); #if HWY_COMPILER_MSVC @@ -6181,12 +7966,14 @@ HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, HWY_FENCE; #endif - BlendedStore(v, FirstN(d, num_of_lanes_to_store), d, p); + BlendedStore(v, FirstN(d, num_lanes_to_store), d, p); #if HWY_COMPILER_MSVC // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore HWY_FENCE; #endif + + detail::MaybeUnpoison(p, num_lanes_to_store); } #if HWY_TARGET > HWY_AVX3 @@ -6214,36 +8001,35 @@ namespace detail { template HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, TFromD* HWY_RESTRICT p, - size_t num_of_lanes_to_store) { + size_t num_lanes_to_store) { // AVX2UIF8Or16StoreTrailingN should only be called for an I8/U8 vector if - // (num_of_lanes_to_store & 3) != 0 is true + // (num_lanes_to_store & 3) != 0 is true const auto v_full128 = 
ResizeBitCast(Full128>(), v_trailing); - if ((num_of_lanes_to_store & 2) != 0) { + if ((num_lanes_to_store & 2) != 0) { const uint16_t u16_bits = GetLane(BitCast(Full128(), v_full128)); - p[num_of_lanes_to_store - 1] = detail::ExtractLane<2>(v_full128); + p[num_lanes_to_store - 1] = detail::ExtractLane<2>(v_full128); CopyBytes(&u16_bits, - p + (num_of_lanes_to_store & ~size_t{3})); + p + (num_lanes_to_store & ~size_t{3})); } else { - p[num_of_lanes_to_store - 1] = GetLane(v_full128); + p[num_lanes_to_store - 1] = GetLane(v_full128); } } template HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, - TFromD* HWY_RESTRICT p, - size_t num_of_lanes_to_store) { + TFromD* p, + size_t num_lanes_to_store) { // AVX2UIF8Or16StoreTrailingN should only be called for an I16/U16/F16/BF16 - // vector if (num_of_lanes_to_store & 1) == 1 is true - p[num_of_lanes_to_store - 1] = GetLane(v_trailing); + // vector if (num_lanes_to_store & 1) == 1 is true + p[num_lanes_to_store - 1] = GetLane(v_trailing); } } // namespace detail template -HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, - size_t max_lanes_to_store) { - const size_t num_of_lanes_to_store = +HWY_API void StoreN(VFromD v, D d, TFromD* p, size_t max_lanes_to_store) { + const size_t num_lanes_to_store = HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> @@ -6252,7 +8038,7 @@ HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, const Repartition di32_full; const auto i32_store_mask = BitCast( - di32_full, VecFromMask(du_full, FirstN(du_full, num_of_lanes_to_store))); + di32_full, VecFromMask(du_full, FirstN(du_full, num_lanes_to_store))); const auto vi32 = ResizeBitCast(di32_full, v); #if HWY_COMPILER_MSVC @@ -6265,19 +8051,21 @@ HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, constexpr size_t kNumOfLanesPerI32 = 4 / sizeof(TFromD); constexpr size_t kTrailingLenMask = kNumOfLanesPerI32 - 1; - const size_t trailing_n = (num_of_lanes_to_store & kTrailingLenMask); + const size_t trailing_n = (num_lanes_to_store & kTrailingLenMask); if (trailing_n != 0) { - const auto v_trailing = ResizeBitCast( + const VFromD v_trailing = ResizeBitCast( d, SlideDownLanes(di32_full, vi32, - num_of_lanes_to_store / kNumOfLanesPerI32)); - detail::AVX2UIF8Or16StoreTrailingN(v_trailing, d, p, num_of_lanes_to_store); + num_lanes_to_store / kNumOfLanesPerI32)); + detail::AVX2UIF8Or16StoreTrailingN(v_trailing, d, p, num_lanes_to_store); } #if HWY_COMPILER_MSVC // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore HWY_FENCE; #endif + + detail::MaybeUnpoison(p, num_lanes_to_store); } #endif // HWY_TARGET > HWY_AVX3 #endif // HWY_TARGET <= HWY_AVX2 @@ -6300,19 +8088,36 @@ HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { // ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) -template +template HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { const RebindToUnsigned du; const Half duh; return BitCast(d, VFromD{_mm_move_epi64(BitCast(duh, lo).raw)}); } -template +template HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { const Half dh; return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); } +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} +#endif + +// Generic for all vector lengths. 
+template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} + // ------------------------------ Concat full (InterleaveLower) // hiH,hiL loH,loL |-> hiL,loL (= lower halves) @@ -6459,10 +8264,11 @@ template HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns // 0xFFFF8000, which correctly saturates to 0x8000. + const RebindToUnsigned du; const Repartition dw; const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); - return VFromD{_mm_packs_epi32(uL.raw, uH.raw)}; + return BitCast(d, VFromD{_mm_packs_epi32(uL.raw, uH.raw)}); } // 16-bit x4 @@ -6565,11 +8371,12 @@ template HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { #if HWY_TARGET <= HWY_SSE4 // Isolate lower 16 bits per u32 so we can pack. + const RebindToUnsigned du; // for float16_t const Repartition dw; const Vec128 mask = Set(dw, 0x0000FFFF); const Vec128 uH = And(BitCast(dw, hi), mask); const Vec128 uL = And(BitCast(dw, lo), mask); - return VFromD{_mm_packus_epi32(uL.raw, uH.raw)}; + return BitCast(d, VFromD{_mm_packus_epi32(uL.raw, uH.raw)}); #elif HWY_TARGET == HWY_SSE2 const Repartition dw; return ConcatOdd(d, BitCast(d, ShiftLeft<16>(BitCast(dw, hi))), @@ -6642,9 +8449,9 @@ HWY_API V DupEven(V v) { #if HWY_TARGET <= HWY_SSSE3 const RebindToUnsigned du; - alignas(16) static constexpr uint8_t kShuffle[16] = { - 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; - return TableLookupBytes(v, BitCast(d, LoadDup128(du, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du, 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14); + return TableLookupBytes(v, BitCast(d, shuffle)); #else const Repartition du16; return IfVecThenElse(BitCast(d, Set(du16, uint16_t{0xFF00})), @@ -6656,8 +8463,8 @@ template HWY_API Vec64 DupEven(const Vec64 v) { const DFromV d; const RebindToUnsigned du; // for float16_t - return BitCast(d, VFromD{ - _mm_shufflelo_epi16(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}); + return BitCast(d, VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, _MM_SHUFFLE(2, 2, 0, 0))}); } // Generic for all vector lengths. 
@@ -6666,9 +8473,9 @@ HWY_API V DupEven(const V v) { const DFromV d; const RebindToUnsigned du; // for float16_t #if HWY_TARGET <= HWY_SSSE3 - alignas(16) static constexpr uint16_t kShuffle[8] = { - 0x0100, 0x0100, 0x0504, 0x0504, 0x0908, 0x0908, 0x0d0c, 0x0d0c}; - return TableLookupBytes(v, BitCast(d, LoadDup128(du, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du, 0x0100, 0x0100, 0x0504, 0x0504, 0x0908, 0x0908, 0x0d0c, 0x0d0c); + return TableLookupBytes(v, BitCast(d, shuffle)); #else return BitCast( d, VFromD{_mm_shufflehi_epi16( @@ -6699,9 +8506,9 @@ HWY_API V DupOdd(V v) { #if HWY_TARGET <= HWY_SSSE3 const RebindToUnsigned du; - alignas(16) static constexpr uint8_t kShuffle[16] = { - 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}; - return TableLookupBytes(v, BitCast(d, LoadDup128(du, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15); + return TableLookupBytes(v, BitCast(d, shuffle)); #else const Repartition du16; return IfVecThenElse(BitCast(d, Set(du16, uint16_t{0x00FF})), @@ -6723,9 +8530,9 @@ HWY_API V DupOdd(V v) { const DFromV d; const RebindToUnsigned du; // for float16_t #if HWY_TARGET <= HWY_SSSE3 - alignas(16) static constexpr uint16_t kShuffle[8] = { - 0x0302, 0x0302, 0x0706, 0x0706, 0x0b0a, 0x0b0a, 0x0f0e, 0x0f0e}; - return TableLookupBytes(v, BitCast(d, LoadDup128(du, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + du, 0x0302, 0x0302, 0x0706, 0x0706, 0x0b0a, 0x0b0a, 0x0f0e, 0x0f0e); + return TableLookupBytes(v, BitCast(d, shuffle)); #else return BitCast( d, VFromD{_mm_shufflehi_epi16( @@ -6952,14 +8759,16 @@ HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { template HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { -#if HWY_TARGET >= HWY_SSSE3 const DFromV d; +#if HWY_TARGET >= HWY_SSSE3 const Repartition d8; alignas(16) static constexpr uint8_t mask[16] = { 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); #else - return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x55)}; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_blend_epi16( + BitCast(du, a).raw, BitCast(du, b).raw, 0x55)}); #endif } @@ -7008,6 +8817,94 @@ HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { #endif } +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + return ConcatEven(d, b, a); +} + +// I8/U8 InterleaveEven is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(BitCast(d, ShiftLeft<8>(BitCast(du16, b))), a); +} + +// I16/U16 InterleaveEven is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(BitCast(d, ShiftLeft<16>(BitCast(du32, b))), a); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0x0A), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0x0A), + b.raw, b.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto 
b2_b0_a2_a0 = ConcatEven(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return ConcatOdd(d, b, a); +} + +// I8/U8 InterleaveOdd is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(b, BitCast(d, ShiftRight<8>(BitCast(du16, a)))); +} + +// I16/U16 InterleaveOdd is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(b, BitCast(d, ShiftRight<16>(BitCast(du32, a)))); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x05), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x05), + a.raw, a.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto b3_b1_a3_a1 = ConcatOdd(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b3_b1_a3_a1.raw, b3_b1_a3_a1.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + // ------------------------------ OddEvenBlocks template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { @@ -7028,6 +8925,7 @@ HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { // to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. namespace detail { + #if HWY_TARGET == HWY_AVX2 // Unused for AVX3 - we use sllv directly template HWY_API V AVX2ShlU16Vec128(V v, V bits) { @@ -7036,6 +8934,22 @@ HWY_API V AVX2ShlU16Vec128(V v, V bits) { return TruncateTo(d, PromoteTo(du32, v) << PromoteTo(du32, bits)); } #elif HWY_TARGET > HWY_AVX2 + +template +static HWY_INLINE VFromD Pow2ConvF32ToI32( + D32 d32, VFromD> vf32) { + const RebindToSigned di32; +#if HWY_COMPILER_GCC_ACTUAL + // ConvertInRangeTo is safe with GCC due the inline assembly workaround used + // for F32->I32 ConvertInRangeTo with GCC + return BitCast(d32, ConvertInRangeTo(di32, vf32)); +#else + // Otherwise, use NearestIntInRange because we rely on the native 0x80..00 + // overflow behavior + return BitCast(d32, NearestIntInRange(di32, vf32)); +#endif +} + // Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. template HWY_INLINE Vec128> Pow2(const Vec128 v) { @@ -7051,8 +8965,8 @@ HWY_INLINE Vec128> Pow2(const Vec128 v) { const auto f0 = ZipLower(dw, zero, upper); const auto f1 = ZipUpper(dw, zero, upper); // See cvtps comment below. - const VFromD bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; - const VFromD bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + const VFromD bits0 = Pow2ConvF32ToI32(dw, BitCast(df, f0)); + const VFromD bits1 = Pow2ConvF32ToI32(dw, BitCast(df, f1)); #if HWY_TARGET <= HWY_SSE4 return VFromD{_mm_packus_epi32(bits0.raw, bits1.raw)}; #else @@ -7073,7 +8987,8 @@ HWY_INLINE Vec128, N> Pow2(const Vec128 v) { // Insert 0 into lower halves for reinterpreting as binary32. const auto f0 = ZipLower(dt_w, Zero(dt_u), ResizeBitCast(dt_u, upper)); // See cvtps comment below. 
- const VFromD bits0{_mm_cvtps_epi32(BitCast(dt_f, f0).raw)}; + const VFromD bits0 = + Pow2ConvF32ToI32(dt_w, BitCast(dt_f, f0)); #if HWY_TARGET <= HWY_SSE4 return VFromD{_mm_packus_epi32(bits0.raw, bits0.raw)}; #elif HWY_TARGET == HWY_SSSE3 @@ -7091,11 +9006,12 @@ HWY_INLINE Vec128, N> Pow2(const Vec128 v) { template HWY_INLINE Vec128, N> Pow2(const Vec128 v) { const DFromV d; + const RebindToFloat df; const auto exp = ShiftLeft<23>(v); const auto f = exp + Set(d, 0x3F800000); // 1.0f // Do not use ConvertTo because we rely on the native 0x80..00 overflow // behavior. - return Vec128, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; + return Pow2ConvF32ToI32(d, BitCast(df, f)); } #endif // HWY_TARGET > HWY_AVX2 @@ -7571,48 +9487,180 @@ HWY_API Vec128 operator>>(Vec128 v, const DFromV d; return detail::SignedShr(d, v, bits); #endif -} +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +namespace detail { + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.txt + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + mulH = MulEven(aH, bH) + w1 + k; + return ShiftLeft<32>(t) + w3; +} + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + + VU64 unsigned_mulH; + const auto mulL = BitCast( + di64, SSE2Mul128(BitCast(du64, a), BitCast(du64, b), unsigned_mulH)); + mulH = BitCast(di64, unsigned_mulH) - And(BroadcastSignBit(a), b) - + And(a, BroadcastSignBit(b)); + return mulL; +} + +} // namespace detail + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulEven(V a, V b) { + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveLower(mulL, mulH); +} + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulOdd(V a, V b) { + const DFromV du64; + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveUpper(du64, mulL, mulH); +} + +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 
8 : 0))> +HWY_API V MulHigh(V a, V b) { + V mulH; + detail::SSE2Mul128(a, b, mulH); + return mulH; +} + +#if HWY_ARCH_X86_64 + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + const Half d2; + alignas(16) T mul[2]; + const T a1 = GetLane(UpperHalf(d2, a)); + const T b1 = GetLane(UpperHalf(d2, b)); + mul[0] = Mul128(a1, b1, &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Vec64{_mm_cvtsi64_si128(static_cast(hi))}; +} + +#endif // HWY_ARCH_X86_64 + +// ================================================== CONVERT (2) + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec64 v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec128 v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatEven(d_from, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + V v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatOdd(d_from, v, v)); +} + +} // namespace detail +#endif -// ------------------------------ MulEven/Odd 64x64 (UpperHalf) +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" -HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { - const DFromV d; - alignas(16) uint64_t mul[2]; - mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); - return Load(d, mul); -} +// ------------------------------ WidenMulPairwiseAdd (PromoteEvenTo) -HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { - const DFromV d; - const Half d2; - alignas(16) uint64_t mul[2]; - const uint64_t a1 = GetLane(UpperHalf(d2, a)); - const uint64_t b1 = GetLane(UpperHalf(d2, b)); - mul[0] = Mul128(a1, b1, &mul[1]); - return Load(d, mul); +#if HWY_NATIVE_DOT_BF16 + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m128bh>(a.raw), + reinterpret_cast<__m128bh>(b.raw))}; } -// ------------------------------ WidenMulPairwiseAdd +#else // Generic for all vector lengths. -template >> -HWY_API VFromD WidenMulPairwiseAdd(D32 df32, V16 a, V16 b) { - // TODO(janwas): _mm_dpbf16_ps when available - const RebindToUnsigned du32; - // Lane order within sum0/1 is undefined, hence we can avoid the - // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip - // leads to the odd/even order that RearrangeToOddPlusEven prefers. 
- using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), - Mul(BitCast(df32, ao), BitCast(df32, bo))); +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); } +#endif // HWY_NATIVE_DOT_BF16 + // Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. template >> @@ -7654,331 +9702,149 @@ HWY_API VFromD SatWidenMulPairwiseAdd( #endif -// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ShiftLeft) - -// Generic for all vector lengths. -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, - const VFromD sum0, - VFromD& sum1) { - // TODO(janwas): _mm_dpbf16_ps when available - const RebindToUnsigned du32; - // Lane order within sum0/1 is undefined, hence we can avoid the - // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip - // leads to the odd/even order that RearrangeToOddPlusEven prefers. - using VU32 = VFromD; - const VU32 odd = Set(du32, 0xFFFF0000u); - const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); - const VU32 ao = And(BitCast(du32, a), odd); - const VU32 be = ShiftLeft<16>(BitCast(du32, b)); - const VU32 bo = And(BitCast(du32, b), odd); - sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); - return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); -} - -// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. -template >> -HWY_API VFromD ReorderWidenMulAccumulate(D32 d, V16 a, V16 b, - const VFromD sum0, - VFromD& /*sum1*/) { - (void)d; -#if HWY_TARGET <= HWY_AVX3_DL - return VFromD{_mm_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; -#else - return sum0 + WidenMulPairwiseAdd(d, a, b); -#endif -} - -template >> -HWY_API VFromD ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b, - const VFromD sum0, - VFromD& /*sum1*/) { - (void)d; - return sum0 + WidenMulPairwiseAdd(d, a, b); -} - -// ------------------------------ RearrangeToOddPlusEven -template -HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, - Vec128 /*sum1*/) { - return sum0; // invariant already holds -} - -template -HWY_API Vec128 RearrangeToOddPlusEven( - const Vec128 sum0, Vec128 /*sum1*/) { - return sum0; // invariant already holds -} - -template -HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { - return Add(sum0, sum1); -} +// ------------------------------ SatWidenMulPairwiseAccumulate -// ------------------------------ SumOfMulQuadAccumulate #if HWY_TARGET <= HWY_AVX3_DL -#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE -#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM #else -#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM #endif +// Even if N=1, the I16 vectors have at least 2 lanes, hence _mm_dpwssds_epi32 +// is safe. 
template -HWY_API VFromD SumOfMulQuadAccumulate( - DI32 /*di32*/, VFromD> a_u, - VFromD> b_i, VFromD sum) { - return VFromD{_mm_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; -} - -#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE -#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE -#else -#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE -#endif -template -HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, - VFromD> a, - VFromD> b, - VFromD sum) { - // TODO(janwas): AVX-VNNI-INT8 has dpbssd. - const Repartition du8; - - const auto a_u = BitCast(du8, a); - const auto result_sum_0 = SumOfMulQuadAccumulate(di32, a_u, b, sum); - const auto result_sum_1 = ShiftLeft<8>( - SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32))); - return result_sum_0 - result_sum_1; -} - -#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE -#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE -#else -#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE -#endif -template -HWY_API VFromD SumOfMulQuadAccumulate( - DU32 du32, VFromD> a, - VFromD> b, VFromD sum) { - // TODO(janwas): AVX-VNNI-INT8 has dpbuud. - const Repartition du8; - const RebindToSigned di8; - const RebindToSigned di32; - - const auto b_i = BitCast(di8, b); - const auto result_sum_0 = - SumOfMulQuadAccumulate(di32, a, b_i, BitCast(di32, sum)); - const auto result_sum_1 = ShiftLeft<8>( - SumOfMulQuadAccumulate(di32, a, BroadcastSignBit(b_i), Zero(di32))); - - return BitCast(du32, result_sum_0 - result_sum_1); +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm_dpwssds_epi32(sum.raw, a.raw, b.raw)}; } #endif // HWY_TARGET <= HWY_AVX3_DL -// ================================================== CONVERT - -// ------------------------------ Promotions (part w/ narrow lanes -> full) - -// Unsigned: zero-extend. 
-template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - const __m128i zero = _mm_setzero_si128(); - return VFromD{_mm_unpacklo_epi8(v.raw, zero)}; -#else - return VFromD{_mm_cvtepu8_epi16(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - return VFromD{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; -#else - return VFromD{_mm_cvtepu16_epi32(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - return VFromD{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; -#else - return VFromD{_mm_cvtepu32_epi64(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - const __m128i zero = _mm_setzero_si128(); - const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); - return VFromD{_mm_unpacklo_epi16(u16, zero)}; -#else - return VFromD{_mm_cvtepu8_epi32(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D d, VFromD> v) { -#if HWY_TARGET > HWY_SSSE3 - const Rebind du32; - return PromoteTo(d, PromoteTo(du32, v)); -#elif HWY_TARGET == HWY_SSSE3 - alignas(16) static constexpr int8_t kShuffle[16] = { - 0, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1}; - const Repartition di8; - return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); -#else - (void)d; - return VFromD{_mm_cvtepu8_epi64(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D d, VFromD> v) { -#if HWY_TARGET > HWY_SSSE3 - const Rebind du32; - return PromoteTo(d, PromoteTo(du32, v)); -#elif HWY_TARGET == HWY_SSSE3 - alignas(16) static constexpr int8_t kShuffle[16] = { - 0, 1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, -1, -1}; - const Repartition di8; - return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); -#else - (void)d; - return VFromD{_mm_cvtepu16_epi64(v.raw)}; -#endif -} - -// Unsigned to signed: same plus cast. -template ), sizeof(TFromV)), - HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))> -HWY_API VFromD PromoteTo(D di, V v) { - const RebindToUnsigned du; - return BitCast(di, PromoteTo(du, v)); -} - -// Signed: replicate sign bit. 
-template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - return ShiftRight<8>(VFromD{_mm_unpacklo_epi8(v.raw, v.raw)}); -#else - return VFromD{_mm_cvtepi8_epi16(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - return ShiftRight<16>(VFromD{_mm_unpacklo_epi16(v.raw, v.raw)}); -#else - return VFromD{_mm_cvtepi16_epi32(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - return ShiftRight<32>(VFromD{_mm_unpacklo_epi32(v.raw, v.raw)}); -#else - return VFromD{_mm_cvtepi32_epi64(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); - const __m128i x4 = _mm_unpacklo_epi16(x2, x2); - return ShiftRight<24>(VFromD{x4}); -#else - return VFromD{_mm_cvtepi8_epi32(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D d, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - const Repartition di32; - const Half dh_i32; - const VFromD x4{PromoteTo(dh_i32, v).raw}; - const VFromD s4{ - _mm_shufflelo_epi16(x4.raw, _MM_SHUFFLE(3, 3, 1, 1))}; - return ZipLower(d, x4, s4); -#else - (void)d; - return VFromD{_mm_cvtepi8_epi64(v.raw)}; -#endif -} -template -HWY_API VFromD PromoteTo(D d, VFromD> v) { -#if HWY_TARGET >= HWY_SSSE3 - const Repartition di32; - const Half dh_i32; - const VFromD x2{PromoteTo(dh_i32, v).raw}; - const VFromD s2{ - _mm_shufflelo_epi16(x2.raw, _MM_SHUFFLE(3, 3, 1, 1))}; - return ZipLower(d, x2, s2); -#else - (void)d; - return VFromD{_mm_cvtepi16_epi64(v.raw)}; -#endif -} +// ------------------------------ ReorderWidenMulAccumulate (PromoteEvenTo) -#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) +#if HWY_NATIVE_DOT_BF16 -// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. -#ifdef HWY_NATIVE_F16C -#undef HWY_NATIVE_F16C +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 #else -#define HWY_NATIVE_F16C +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 #endif -// Workaround for origin tracking bug in Clang msan prior to 11.0 -// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") -#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) -#define HWY_INLINE_F16 HWY_NOINLINE +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm_dpbf16_ps(sum0.raw, reinterpret_cast<__m128bh>(a.raw), + reinterpret_cast<__m128bh>(b.raw))}; +} + +#endif // HWY_NATIVE_DOT_BF16 + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. 
+template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 d, V16 a, V16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; #else -#define HWY_INLINE_F16 HWY_INLINE + return sum0 + WidenMulPairwiseAdd(d, a, b); #endif -template -HWY_INLINE_F16 VFromD PromoteTo(D /*tag*/, VFromD> v) { - return VFromD{_mm_cvtph_ps(v.raw)}; } -#endif // HWY_NATIVE_F16C +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; + return sum0 + WidenMulPairwiseAdd(d, a, b); +} -template -HWY_API VFromD PromoteTo(D df32, VFromD> v) { - const Rebind du16; - const RebindToSigned di32; - return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds } -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { - return VFromD{_mm_cvtps_pd(v.raw)}; +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, Vec128 /*sum1*/) { + return sum0; // invariant already holds } -template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { - return VFromD{_mm_cvtepi32_pd(v.raw)}; +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); } -#if HWY_TARGET <= HWY_AVX3 -template -HWY_API VFromD PromoteTo(D /*df64*/, VFromD> v) { - return VFromD{_mm_cvtepu32_pd(v.raw)}; +// ------------------------------ SumOfMulQuadAccumulate +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; } + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE #else -// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 -template -HWY_API VFromD PromoteTo(D df64, VFromD> v) { - const Rebind di32; - const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v)); - return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result, - Set(df64, 4294967296.0), - Zero(df64)); +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + // TODO(janwas): AVX-VNNI-INT8 has dpbssd. + const Repartition du8; + + const auto a_u = BitCast(du8, a); + const auto result_sum_0 = SumOfMulQuadAccumulate(di32, a_u, b, sum); + const auto result_sum_1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32))); + return result_sum_0 - result_sum_1; } + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #endif +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 du32, VFromD> a, + VFromD> b, VFromD sum) { + // TODO(janwas): AVX-VNNI-INT8 has dpbuud. 
+ const Repartition du8; + const RebindToSigned di8; + const RebindToSigned di32; + + const auto b_i = BitCast(di8, b); + const auto result_sum_0 = + SumOfMulQuadAccumulate(di32, a, b_i, BitCast(di32, sum)); + const auto result_sum_1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, a, BroadcastSignBit(b_i), Zero(di32))); + + return BitCast(du32, result_sum_0 - result_sum_1); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL // ------------------------------ Demotions (full -> part w/ narrow lanes) @@ -8143,33 +10009,93 @@ HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") template -HWY_API VFromD DemoteTo(D /*tag*/, VFromD> v) { - return VFromD{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); } HWY_DIAGNOSTICS(pop) #endif // F16C +#if HWY_HAVE_FLOAT16 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D /*df16*/, VFromD> v) { + return VFromD{_mm_cvtpd_ph(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +// The _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics require GCC 9 or later +// or Clang 10 or later + +// Also need GCC or Clang to bit cast the __m128bh, __m256bh, or __m512bh vector +// returned by the _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics to a +// __m128i, __m256i, or __m512i as there are currently no intrinsics available +// (as of GCC 13 and Clang 17) to bit cast a __m128bh, __m256bh, or __m512bh +// vector to a __m128i, __m256i, or __m512i vector + +#if HWY_AVX3_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + template -HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { - // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. - const Rebind di32; - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); - return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec128 a, + Vec128 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtne2ps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32( + detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw)), + _MM_SHUFFLE(2, 0, 2, 0))}; } -template >> -HWY_API VFromD ReorderDemote2To(D dbf16, V32 a, V32 b) { - // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. 
- const RebindToUnsigned du16; - const Repartition du32; - const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +template +HWY_API VFromD ReorderDemote2To(D dbf16, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dbf16, Combine(dt, b, a)); } +#endif // HWY_AVX3_HAVE_F32_TO_BF16C // Specializations for partial vectors because packs_epi32 sets lanes above 2*N. template @@ -8328,11 +10254,15 @@ HWY_API VFromD OrderedDemote2To(D d, V a, V b) { return ReorderDemote2To(d, a, b); } -template >> -HWY_API VFromD OrderedDemote2To(D dbf16, V32 a, V32 b) { - const RebindToUnsigned du16; - return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); +#if HWY_AVX3_HAVE_F32_TO_BF16C +// F32 to BF16 OrderedDemote2To is generic for all vector lengths on targets +// that support AVX512BF16 +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + return ReorderDemote2To(dbf16, a, b); } +#endif // HWY_AVX3_HAVE_F32_TO_BF16C template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { @@ -8349,65 +10279,176 @@ HWY_INLINE VFromD ClampF64ToI32Max(D d, VFromD v) { return Min(v, Set(d, 2147483647.0)); } -// For ConvertTo float->int of same size, clamping before conversion would -// change the result because the max integer value is not exactly representable. -// Instead detect the overflow result after conversion and fix it. -// Generic for all vector lengths. -template -HWY_INLINE VFromD FixConversionOverflow(DI di, - VFromD> original, - VFromD converted) { - // Combinations of original and output sign: - // --: normal <0 or -huge_val to 80..00: OK - // -+: -0 to 0 : OK - // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF - // ++: normal >0 : OK - const VFromD sign_wrong = AndNot(BitCast(di, original), converted); -#if HWY_COMPILER_GCC_ACTUAL - // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also - // Add() if using that instead. Work around with one more instruction. - const RebindToUnsigned du; - const VFromD mask = BroadcastSignBit(sign_wrong); - const VFromD max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); - return IfVecThenElse(mask, max, converted); +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +template +static constexpr HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::FloatTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SpecialTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kMinResultVal = LimitsMin(); + HWY_BITCASTSCALAR_CONSTEXPR const TFArith kMinOutOfRangePosVal = + ScalarAbs(ConvertScalarTo(kMinResultVal)); + + return (ScalarAbs(from_val_in_arith_type) < kMinOutOfRangePosVal) + ? 
ConvertScalarTo(from_val_in_arith_type) + : kMinResultVal; +} + +template +static HWY_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::UnsignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; #else - return Xor(converted, BroadcastSignBit(sign_wrong)); + using TFArith = If>; #endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kTToMsb = static_cast(TTo{1} << (sizeof(TTo) * 8 - 1)); + constexpr const TFArith kNegOne = ConvertScalarTo(-1.0); + constexpr const TFArith kMinOutOfRangePosVal = + ConvertScalarTo(static_cast(kTToMsb) * 2.0); + + return (from_val_in_arith_type > kNegOne && + from_val_in_arith_type < kMinOutOfRangePosVal) + ? ConvertScalarTo(from_val_in_arith_type) + : LimitsMax(); +} + +template +static constexpr HWY_INLINE HWY_MAYBE_UNUSED TTo +X86ConvertScalarFromFloat(TF from_val) { + return X86ConvertScalarFromFloat(hwy::TypeTag>(), + from_val); } +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD } // namespace detail -template > -HWY_API VFromD DemoteTo(D /* tag */, VFromD v) { - const VFromD clamped = detail::ClampF64ToI32Max(DF(), v); - return VFromD{_mm_cvttpd_epi32(clamped.raw)}; +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), int32_t{0}, + int32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("%vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi32(v.raw)}; +#endif +} + +// F64 to I32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D di32, VFromD> v) { + const Rebind df64; + const VFromD clamped = detail::ClampF64ToI32Max(df64, v); + return DemoteInRangeTo(di32, clamped); } -template -HWY_API VFromD DemoteTo(D du32, VFromD> v) { #if HWY_TARGET <= HWY_AVX3 - (void)du32; - return VFromD{ - _mm_maskz_cvttpd_epu32(_knot_mask8(MaskFromVec(v).raw), v.raw)}; -#else // AVX2 or earlier +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), uint32_t{0}, + uint32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" 
HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttpd_epu32(v.raw)}; +#endif +} + +// F64->U32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return DemoteInRangeTo(D(), ZeroIfNegative(v)); +} +#else // HWY_TARGET > HWY_AVX3 + +// F64 to U32 DemoteInRangeTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteInRangeTo(D du32, VFromD> v) { + const RebindToSigned di32; const Rebind df64; const RebindToUnsigned du64; - // Clamp v[i] to a value between 0 and 4294967295 - const auto clamped = Min(ZeroIfNegative(v), Set(df64, 4294967295.0)); - const auto k2_31 = Set(df64, 2147483648.0); - const auto clamped_is_ge_k2_31 = (clamped >= k2_31); - const auto clamped_lo31_f64 = - clamped - IfThenElseZero(clamped_is_ge_k2_31, k2_31); - const VFromD clamped_lo31_u32{_mm_cvttpd_epi32(clamped_lo31_f64.raw)}; + const auto v_is_ge_k2_31 = (v >= k2_31); + const auto clamped_lo31_f64 = v - IfThenElseZero(v_is_ge_k2_31, k2_31); + const auto clamped_lo31_u32 = + BitCast(du32, DemoteInRangeTo(di32, clamped_lo31_f64)); const auto clamped_u32_msb = ShiftLeft<31>( - TruncateTo(du32, BitCast(du64, VecFromMask(df64, clamped_is_ge_k2_31)))); + TruncateTo(du32, BitCast(du64, VecFromMask(df64, v_is_ge_k2_31)))); return Or(clamped_lo31_u32, clamped_u32_msb); -#endif } +// F64 to U32 DemoteTo is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D du32, VFromD> v) { + const Rebind df64; + const auto clamped = Min(ZeroIfNegative(v), Set(df64, 4294967295.0)); + return DemoteInRangeTo(du32, clamped); +} +#endif // HWY_TARGET <= HWY_AVX3 + #if HWY_TARGET <= HWY_AVX3 template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { @@ -8496,23 +10537,85 @@ HWY_API Vec128 U8FromU32(const Vec128 v) { } // ------------------------------ F32->UI64 PromoteTo +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + #if HWY_TARGET <= HWY_AVX3 template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epi64(v.raw)}; +#endif +} + +// Generic for all vector lengths. +template HWY_API VFromD PromoteTo(D di64, VFromD> v) { const Rebind df32; const RebindToFloat df64; - const Twice dt_f32; - - return detail::FixConversionOverflow( - di64, - BitCast(df64, InterleaveLower(ResizeBitCast(dt_f32, v), - ResizeBitCast(dt_f32, v))), - VFromD{_mm_cvttps_epi64(v.raw)}); + // We now avoid GCC UB in PromoteInRangeTo via assembly, see #2189 and + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115115. 
Previously we fixed up + // the result afterwards using three instructions. Now we instead check if + // v >= 2^63, and if so replace the output with 2^63-1, which is likely more + // efficient. Note that the previous representable f32 is less than 2^63 and + // thus fits in i64. + const MFromD overflow = RebindMask( + di64, PromoteMaskTo(df64, df32, Ge(v, Set(df32, 9.223372e18f)))); + return IfThenElse(overflow, Set(di64, LimitsMax()), + PromoteInRangeTo(di64, v)); } -template +template HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { - return VFromD{ - _mm_maskz_cvttps_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; + return PromoteInRangeTo(D(), ZeroIfNegative(v)); +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epu64(v.raw)}; +#endif } #else // AVX2 or below @@ -8543,6 +10646,27 @@ HWY_API VFromD PromoteTo(D di64, VFromD> v) { lo64_or_mask); } +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToSigned di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{0xFFFFFF9Du})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); + return PromoteTo(d64, BitCast(d32, f32_to_i32_result)) + << PromoteTo(d64, exponent_adj); +} + namespace detail { template @@ -8583,7 +10707,7 @@ HWY_API VFromD PromoteTo(D du64, VFromD> v) { const auto adj_v = BitCast(df32, BitCast(du32, non_neg_v) - ShiftLeft<23>(exponent_adj)); - const VFromD f32_to_i32_result{_mm_cvttps_epi32(adj_v.raw)}; + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); const auto i32_overflow_mask = BroadcastSignBit(f32_to_i32_result); const auto overflow_result = @@ -8747,32 +10871,17 @@ HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 
non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; } @@ -8788,7 +10897,20 @@ template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { return VFromD{_mm_cvtusepi64_epi8(v.raw)}; } -#else // AVX2 or below +#else // AVX2 or below + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h for U64->I8/I16/I32 demotions on +// SSE2/SSSE3/SSE4/AVX2 as U64->I8/I16/I32 DemoteTo/ReorderDemote2To for +// SSE2/SSSE3/SSE4/AVX2 is implemented in x86_128-inl.h + +// The default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h are still used for U32->I8/I16 and +// U16->I8 demotions on SSE2/SSSE3/SSE4/AVX2 + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) HWY_IF_NOT_T_SIZE_V(V, 8) + namespace detail { template HWY_INLINE VFromD> DemoteFromU64MaskOutResult( @@ -8847,9 +10969,28 @@ HWY_API VFromD DemoteTo(D dn, VFromD> v) { const DFromV di64; const RebindToUnsigned du64; - const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); - return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); + const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::DemoteFromU64Saturate(dn, v))); +} + +#if HWY_TARGET == HWY_SSE2 +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const Rebind di32; + return DemoteTo(dn, DemoteTo(di32, v)); } +#endif // HWY_TARGET == HWY_SSE2 template @@ -8875,6 +11016,16 @@ HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, return DemoteTo(dn, Combine(dt, b, a)); } +#if HWY_TARGET > HWY_AVX3 +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +#endif + #if HWY_TARGET > HWY_AVX2 template HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, @@ -8912,9 +11063,9 @@ HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); } -template -HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, - Vec128 b) { +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { const Half dnh; const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); @@ -9024,209 +11175,767 @@ HWY_API VFromD ConvertTo(D dd, VFromD> v) { // Truncates (rounds toward zero). 
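// Worked example for the saturation logic in these float->int conversions:
// INT32_MAX (2147483647) is exactly representable as a double, so the
// F64->I32 DemoteTo above may simply Min() against 2147483647.0 before
// converting. It is not exactly representable as a float (the nearest float
// is 2147483648.0f), so a same-width F32->I32 conversion cannot clamp first;
// instead, out-of-range lanes are detected and replaced after the in-range
// conversion, mirroring the overflow check used by the F32->I64 PromoteTo
// above. Sketch (hypothetical helper name, assuming the ConvertInRangeTo op
// introduced below):
#if 0  // illustrative only
template <class DI32>
HWY_INLINE VFromD<DI32> SketchSaturatedF32ToI32(DI32 di,
                                                VFromD<RebindToFloat<DI32>> v) {
  const RebindToFloat<DI32> df;
  // 2^31 is exact in float; lanes >= 2^31 would overflow int32_t.
  const MFromD<DI32> overflow = RebindMask(di, Ge(v, Set(df, 2147483648.0f)));
  return IfThenElse(overflow, Set(di, LimitsMax<int32_t>()),
                    ConvertInRangeTo(di, v));
}
#endif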
+#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + #if HWY_HAVE_FLOAT16 template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epi16(v.raw)}; +#endif +} + +// F16 to I16 ConvertTo is generic for all vector lengths +template HWY_API VFromD ConvertTo(D di, VFromD> v) { - return detail::FixConversionOverflow( - di, v, VFromD>{_mm_cvttph_epi16(v.raw)}); + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = + RebindMask(di, Ge(v, Set(df, ConvertScalarTo(32768.0f)))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} + +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epu16 if any values of v[i] + // are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epu16(v.raw)}; +#endif +} + +// F16->U16 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return ConvertInRangeTo(D(), ZeroIfNegative(v)); } #endif // HWY_HAVE_FLOAT16 template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD 
+ if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epi32(v.raw)}; +#endif +} + +// F32 to I32 ConvertTo is generic for all vector lengths +template HWY_API VFromD ConvertTo(D di, VFromD> v) { - return detail::FixConversionOverflow( - di, v, VFromD>{_mm_cvttps_epi32(v.raw)}); + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = RebindMask(di, Ge(v, Set(df, 2147483648.0f))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); } #if HWY_TARGET <= HWY_AVX3 template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi64(v.raw)}; +#endif +} + +// F64 to I64 ConvertTo is generic for all vector lengths on AVX3 +template HWY_API VFromD ConvertTo(DI di, VFromD> v) { - return detail::FixConversionOverflow(di, v, - VFromD{_mm_cvttpd_epi64(v.raw)}); + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); } template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epu32(v.raw)}; +#endif +} + +// F32->U32 ConvertTo is generic for all vector lengths +template HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm_maskz_cvttps_epu32(_knot_mask8(MaskFromVec(v).raw), v.raw)}; + return ConvertInRangeTo(DU(), ZeroIfNegative(v)); } template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epu64(v.raw)}; +#endif +} + +// F64->U64 ConvertTo is generic for all vector lengths +template HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm_maskz_cvttpd_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; + return ConvertInRangeTo(DU(), ZeroIfNegative(v)); } #else // AVX2 or below -template -HWY_API VFromD ConvertTo(DU32 du32, VFromD> v) { +namespace detail { + +template +static HWY_INLINE VFromD ConvInRangeF32ToU32( + DU32 du32, VFromD> v, VFromD& exp_diff) { const RebindToSigned di32; const RebindToFloat df32; - const auto non_neg_v = ZeroIfNegative(v); - const auto exp_diff = Set(di32, int32_t{158}) - - BitCast(di32, ShiftRight<23>(BitCast(du32, non_neg_v))); + exp_diff = Set(du32, uint32_t{158}) - ShiftRight<23>(BitCast(du32, v)); const auto scale_down_f32_val_mask = - BitCast(du32, VecFromMask(di32, Eq(exp_diff, Zero(di32)))); + VecFromMask(du32, Eq(exp_diff, Zero(du32))); + + const auto v_scaled = + BitCast(df32, BitCast(du32, v) + ShiftLeft<23>(scale_down_f32_val_mask)); + const auto f32_to_u32_result = + BitCast(du32, ConvertInRangeTo(di32, v_scaled)); + + return f32_to_u32_result + And(f32_to_u32_result, scale_down_f32_val_mask); +} + +} // namespace detail + +// F32 to U32 ConvertInRangeTo is generic for all vector lengths on 
+// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU32 du32, + VFromD> v) { + VFromD exp_diff; + const auto f32_to_u32_result = detail::ConvInRangeF32ToU32(du32, v, exp_diff); + return f32_to_u32_result; +} + +// F32 to U32 ConvertTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU32 du32, VFromD> v) { + const RebindToSigned di32; + + const auto non_neg_v = ZeroIfNegative(v); + VFromD exp_diff; + const auto f32_to_u32_result = + detail::ConvInRangeF32ToU32(du32, non_neg_v, exp_diff); + + return Or(f32_to_u32_result, + BitCast(du32, BroadcastSignBit(BitCast(di32, exp_diff)))); +} + +namespace detail { + +template +HWY_API VFromD ConvAbsInRangeF64ToUI64(D64 d64, + VFromD> v, + VFromD& biased_exp) { + const RebindToSigned di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + const Repartition du16; + const VU64 k1075 = Set(du64, 1075); /* biased exponent of 2^52 */ + + // Exponent indicates whether the number can be represented as int64_t. + biased_exp = BitCast(d64, ShiftRight<52>(BitCast(du64, v))); + HWY_IF_CONSTEXPR(IsSigned>()) { + biased_exp = And(biased_exp, Set(d64, TFromD{0x7FF})); + } + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + + // 16-bit saturated unsigned subtraction is also more efficient than a + // 64-bit subtraction followed by a 64-bit signed Max operation on + // SSE2/SSSE3/SSE4/AVX2. + + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU64 shift_mnt = BitCast( + du64, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU64 shift_int = BitCast( + du64, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU64 mantissa = BitCast(du64, v) & Set(du64, (1ULL << 52) - 1); + // Include implicit 1-bit. NOTE: the shift count may exceed 63; we rely on x86 + // returning zero in that case. + const VU64 int53 = (mantissa | Set(du64, 1ULL << 52)) >> shift_mnt; + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + + // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 11 is true + // for any inputs that are less than 2^64. 
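// A scalar sketch of the ConvInRangeF32ToU32 idea above: values in
// [2^31, 2^32) cannot go through a signed f32->i32 conversion directly, so
// such lanes are halved first (exponent decremented by one) and the integer
// result is doubled afterwards; halving is exact because floats that large
// are already even integers. Assumes 0 <= v < 2^32; the helper name is ours.
#include <cstdint>

static uint32_t TruncF32ToU32Sketch(float v) {
  if (v >= 2147483648.0f) {                               // v in [2^31, 2^32)
    const int32_t half = static_cast<int32_t>(v * 0.5f);  // now fits in int32
    return static_cast<uint32_t>(half) * 2u;              // scale back up
  }
  return static_cast<uint32_t>(static_cast<int32_t>(v));  // already in i32 range
}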
+ + return BitCast(d64, int53 << shift_int); +} + +} // namespace detail - const auto v_scaled = BitCast( - df32, BitCast(du32, non_neg_v) + ShiftLeft<23>(scale_down_f32_val_mask)); - const VFromD f32_to_u32_result{ - _mm_cvttps_epi32(v_scaled.raw)}; +#if HWY_ARCH_X86_64 + +namespace detail { + +template +static HWY_INLINE int64_t SSE2ConvFirstF64LaneToI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ConvertScalarFromFloat(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvttsd2si {%1, %0|%0, %1}" + : "=r"(result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvttsd_si64(v.raw); +#endif +} + +} // namespace detail + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec64 v) { + return VFromD{_mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v))}; +} +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec128 v) { + const __m128i i0 = _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v)); + const Full64 dd2; + const __m128i i1 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + + // If the input was negative, negate the integer (two's complement). + return (shifted ^ sign_mask) - sign_mask; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + +#if HWY_TARGET <= HWY_SSE4 + const auto in_range = biased_exp < Set(di, 1086); +#else + const Repartition di32; + const auto in_range = MaskFromVec(BitCast( + di, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) < Set(di32, 1086)))); +#endif + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, shifted, limit); + + // If the input was negative, negate the integer (two's complement). 
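// A scalar sketch of the exponent/mantissa shifting performed by
// ConvAbsInRangeF64ToUI64 above, plus the (x ^ sign) - sign conditional
// negation used just below for the signed case. Assumes a finite input with
// magnitude below 2^63 and omits the separate saturation path; the helper
// name is ours. Unlike the vector code, which relies on x86 shifts >= 64
// producing zero, plain C++ must guard those shift counts explicitly.
#include <cstdint>
#include <cstring>

static int64_t TruncF64ToI64Sketch(double v) {
  uint64_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  const uint64_t sign = bits >> 63;
  const uint64_t biased_exp = (bits >> 52) & 0x7FF;
  const uint64_t mantissa =
      (bits & ((1ULL << 52) - 1)) | (1ULL << 52);  // include the implicit 1-bit
  // 1075 is the biased exponent of 2^52: below it, shift fraction bits out
  // (truncate toward zero); above it, shift the 53-bit integer into place.
  const uint64_t shift_mnt = biased_exp < 1075 ? 1075 - biased_exp : 0;
  const uint64_t shift_int = biased_exp > 1075 ? biased_exp - 1075 : 0;
  const uint64_t int53 = shift_mnt >= 64 ? 0 : (mantissa >> shift_mnt);
  const uint64_t magnitude = shift_int >= 64 ? 0 : (int53 << shift_int);
  const uint64_t s = 0 - sign;                       // 0 or all-ones
  return static_cast<int64_t>((magnitude ^ s) - s);  // conditional negation
}
// Examples: TruncF64ToI64Sketch(-2.75) == -2,
//           TruncF64ToI64Sketch(1e15) == 1000000000000000.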
+ return (magnitude ^ sign_mask) - sign_mask; +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU du, VFromD> v) { + VFromD biased_exp; + const auto shifted = detail::ConvAbsInRangeF64ToUI64(du, v, biased_exp); + return shifted; +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU du, VFromD> v) { + const RebindToSigned di; + using VU = VFromD; + + VU biased_exp; + const VU shifted = + detail::ConvAbsInRangeF64ToUI64(du, ZeroIfNegative(v), biased_exp); + + // Exponent indicates whether the number can be represented as uint64_t. +#if HWY_TARGET <= HWY_SSE4 + const VU out_of_range = + BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); +#else + const Repartition di32; + const VU out_of_range = BitCast( + du, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) > Set(di32, 1086))); +#endif + + return (shifted | out_of_range); +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CXX14_CONSTEXPR TTo +X86ScalarNearestInt(TF flt_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TTo trunc_int_val = X86ConvertScalarFromFloat(flt_val); + const TFArith abs_val_diff = ScalarAbs( + ConvertScalarTo(ConvertScalarTo(flt_val) - + ConvertScalarTo(trunc_int_val))); + constexpr TFArith kHalf = ConvertScalarTo(0.5); + + const bool round_result_up = + ((trunc_int_val ^ ScalarShr(trunc_int_val, sizeof(TTo) * 8 - 1)) != + LimitsMax()) && + (abs_val_diff > kHalf || + (abs_val_diff == kHalf && (trunc_int_val & 1) != 0)); + return static_cast( + trunc_int_val + + (round_result_up ? (ScalarSignBit(flt_val) ? (-1) : 1) : 0)); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +// If these are in namespace detail, the x86_256/512 templates are not found. 
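// X86ScalarNearestInt above rounds by truncating and then adjusting away from
// zero when the discarded fraction exceeds one half, or equals one half with
// an odd truncated value (ties to even). A scalar sketch of that rule,
// assuming a finite value within int32_t range so the saturation guard in the
// real code can be omitted; the helper name is ours.
#include <cmath>
#include <cstdint>

static int32_t RoundHalfToEvenSketch(double x) {
  const double t = std::trunc(x);        // round toward zero first
  const double frac = std::fabs(x - t);  // magnitude of the discarded part
  int32_t i = static_cast<int32_t>(t);
  const bool adjust = frac > 0.5 || (frac == 0.5 && (i % 2) != 0);
  if (adjust) i += (x < 0.0) ? -1 : 1;   // move away from zero by one
  return i;
}
// Examples: 2.5 -> 2, 3.5 -> 4, -2.5 -> -2, 2.501 -> 3.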
+template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtps_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi64(v.raw)}; +#endif +} + +#else // HWY_TARGET > HWY_AVX3 + +namespace detail { + +#if HWY_ARCH_X86_64 +template +static HWY_INLINE int64_t +SSE2ConvFirstF64LaneToNearestI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double 
GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ScalarNearestInt(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvtsd2si {%1, %0|%0, %1}" + : "=r"(result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvtsd_si64(v.raw); +#endif +} +#endif // HWY_ARCH_X86_64 - return Or( - BitCast(du32, BroadcastSignBit(exp_diff)), - f32_to_u32_result + And(f32_to_u32_result, scale_down_f32_val_mask)); +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE VFromD SSE2NearestI64InRange( + DI64 di64, VFromD> v) { + const RebindToFloat df64; + const RebindToUnsigned du64; + using VI64 = VFromD; + + const auto mant_end = Set(df64, MantissaEnd()); + const auto is_small = Lt(Abs(v), mant_end); + + const auto adj_v = Max(v, Set(df64, -9223372036854775808.0)) + + IfThenElseZero(is_small, CopySignToAbs(mant_end, v)); + const auto adj_v_biased_exp = + And(BitCast(di64, ShiftRight<52>(BitCast(du64, adj_v))), + Set(di64, int64_t{0x7FF})); + + // We can simply subtract 1075 from adj_v_biased_exp[i] to get shift_int since + // adj_v_biased_exp[i] is at least 1075 + const VI64 shift_int = adj_v_biased_exp + Set(di64, int64_t{-1075}); + + const VI64 mantissa = BitCast(di64, adj_v) & Set(di64, (1LL << 52) - 1); + // Include implicit 1-bit if is_small[i] is 0. NOTE: the shift count may + // exceed 63; we rely on x86 returning zero in that case. + const VI64 int53 = mantissa | IfThenZeroElse(RebindMask(di64, is_small), + Set(di64, 1LL << 52)); + + const VI64 sign_mask = BroadcastSignBit(BitCast(di64, v)); + // If the input was negative, negate the integer (two's complement). + return ((int53 << shift_int) ^ sign_mask) - sign_mask; } +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +} // namespace detail #if HWY_ARCH_X86_64 template -HWY_API VFromD ConvertTo(DI di, Vec64 v) { - const Vec64 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; - return detail::FixConversionOverflow(di, v, i0); +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec64 v) { + return VFromD{ + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v))}; } template -HWY_API VFromD ConvertTo(DI di, Vec128 v) { - const __m128i i0 = _mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw)); +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec128 v) { + const __m128i i0 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v)); const Full64 dd2; - const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); - return detail::FixConversionOverflow( - di, v, Vec128{_mm_unpacklo_epi64(i0, i1)}); + const __m128i i1 = _mm_cvtsi64_si128( + detail::SSE2ConvFirstF64LaneToNearestI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; } #endif // HWY_ARCH_X86_64 #if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 template -HWY_API VFromD ConvertTo(DI di, VFromD> v) { - using VI = VFromD; - const RebindToUnsigned du; - using VU = VFromD; - const Repartition du16; - const VI k1075 = Set(di, 1075); /* biased exponent of 2^52 */ - - // Exponent indicates whether the number can be represented as int64_t. 
- const VU biased_exp = ShiftRight<52>(BitCast(du, v)) & Set(du, 0x7FF); -#if HWY_TARGET <= HWY_SSE4 - const auto in_range = BitCast(di, biased_exp) < Set(di, 1086); -#else - const Repartition di32; - const auto in_range = MaskFromVec(BitCast( - di, - VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) < Set(di32, 1086)))); -#endif - - // If we were to cap the exponent at 51 and add 2^52, the number would be in - // [2^52, 2^53) and mantissa bits could be read out directly. We need to - // round-to-0 (truncate), but changing rounding mode in MXCSR hits a - // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead - // manually shift the mantissa into place (we already have many of the - // inputs anyway). - - // Use 16-bit saturated unsigned subtraction to compute shift_mnt and - // shift_int since biased_exp[i] is a non-negative integer that is less than - // or equal to 2047. - - // 16-bit saturated unsigned subtraction is also more efficient than a - // 64-bit subtraction followed by a 64-bit signed Max operation on - // SSE2/SSSE3/SSE4/AVX2. - - // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be - // zero as the upper 48 bits of both k1075 and biased_exp are zero. - - const VU shift_mnt = BitCast( - du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); - const VU shift_int = BitCast( - du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); - const VU mantissa = BitCast(du, v) & Set(du, (1ULL << 52) - 1); - // Include implicit 1-bit. NOTE: the shift count may exceed 63; we rely on x86 - // returning zero in that case. - const VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; - - // For inputs larger than 2^53 - 1, insert zeros at the bottom. - - // For inputs less than 2^63, the implicit 1-bit is guaranteed not to be - // shifted out of the left shift result below as shift_int[i] <= 10 is true - // for any inputs that are less than 2^63. - - const VU shifted = int53 << shift_int; - - // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. - const VI sign_mask = BroadcastSignBit(BitCast(di, v)); - const VI limit = Set(di, LimitsMax()) - sign_mask; - const VI magnitude = IfThenElse(in_range, BitCast(di, shifted), limit); - - // If the input was negative, negate the integer (two's complement). - return (magnitude ^ sign_mask) - sign_mask; +static HWY_INLINE VFromD NearestIntInRange(DI di, + VFromD> v) { + return detail::SSE2NearestI64InRange(di, v); } -#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 - -// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 -template -HWY_API VFromD ConvertTo(DU du, VFromD> v) { - const RebindToSigned di; - using VU = VFromD; - const Repartition du16; - const VU k1075 = Set(du, 1075); /* biased exponent of 2^52 */ +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 - const auto non_neg_v = ZeroIfNegative(v); +#endif // HWY_TARGET <= HWY_AVX3 - // Exponent indicates whether the number can be represented as int64_t. 
- const VU biased_exp = ShiftRight<52>(BitCast(du, non_neg_v)); -#if HWY_TARGET <= HWY_SSE4 - const VU out_of_range = - BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); -#else - const Repartition di32; - const VU out_of_range = BitCast( - du, - VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) > Set(di32, 1086))); +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), int32_t{0}, int32_t{0}); + } #endif - // If we were to cap the exponent at 51 and add 2^52, the number would be in - // [2^52, 2^53) and mantissa bits could be read out directly. We need to - // round-to-0 (truncate), but changing rounding mode in MXCSR hits a - // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead - // manually shift the mantissa into place (we already have many of the - // inputs anyway). - - // Use 16-bit saturated unsigned subtraction to compute shift_mnt and - // shift_int since biased_exp[i] is a non-negative integer that is less than - // or equal to 2047. - - // 16-bit saturated unsigned subtraction is also more efficient than a - // 64-bit subtraction followed by a 64-bit signed Max operation on - // SSE2/SSSE3/SSE4/AVX2. - - // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be - // zero as the upper 48 bits of both k1075 and biased_exp are zero. + __m128i raw_result; + __asm__("%vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi32(v.raw)}; +#endif +} - const VU shift_mnt = BitCast( - du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); - const VU shift_int = BitCast( - du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); - const VU mantissa = BitCast(du, non_neg_v) & Set(du, (1ULL << 52) - 1); - // Include implicit 1-bit. NOTE: the shift count may exceed 63; we rely on x86 - // returning zero in that case. - const VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; +// F16/F32/F64 NearestInt is generic for all vector lengths +template , class DI = RebindToSigned, + HWY_IF_FLOAT_D(DF), + HWY_IF_T_SIZE_ONE_OF_D(DF, (1 << 4) | (1 << 8) | + (HWY_HAVE_FLOAT16 ? (1 << 2) : 0))> +HWY_API VFromD NearestInt(const VF v) { + const DI di; + using TI = TFromD; + using TF = TFromD; + using TFArith = If>; - // For inputs larger than 2^53 - 1, insert zeros at the bottom. + constexpr TFArith kMinOutOfRangePosVal = + static_cast(-static_cast(LimitsMin())); + static_assert(kMinOutOfRangePosVal > static_cast(0.0), + "kMinOutOfRangePosVal > 0.0 must be true"); - // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be - // shifted out of the left shift result below as shift_int[i] <= 11 is true - // for any inputs that are less than 2^64. + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ // Here we are rounding, whereas previous occurrences truncate, but there is + // no difference because the previous float value is well below the max i32. + const auto overflow = RebindMask( + di, Ge(v, Set(DF(), ConvertScalarTo(kMinOutOfRangePosVal)))); + auto result = + IfThenElse(overflow, Set(di, LimitsMax()), NearestIntInRange(di, v)); - const VU shifted = int53 << shift_int; - return (shifted | out_of_range); + return result; } -#endif // HWY_TARGET <= HWY_AVX3 -template -HWY_API Vec128 NearestInt(const Vec128 v) { - const RebindToSigned> di; - return detail::FixConversionOverflow( - di, v, VFromD{_mm_cvtps_epi32(v.raw)}); +template +HWY_API VFromD DemoteToNearestInt(DI, VFromD> v) { + const DI di; + const Rebind df64; + return DemoteToNearestIntInRange(di, Min(v, Set(df64, 2147483647.0))); } // ------------------------------ Floating-point rounding (ConvertTo) @@ -9270,7 +11979,7 @@ HWY_API Vec128 Trunc(const Vec128 v) { const DFromV df; const RebindToSigned di; - const auto integer = ConvertTo(di, v); // round toward 0 + const auto integer = ConvertInRangeTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); @@ -9283,7 +11992,7 @@ HWY_API Vec128 Ceil(const Vec128 v) { const DFromV df; const RebindToSigned di; - const auto integer = ConvertTo(di, v); // round toward 0 + const auto integer = ConvertInRangeTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); // Truncating a positive non-integer ends up smaller; if so, add 1. @@ -9292,6 +12001,25 @@ HWY_API Vec128 Ceil(const Vec128 v) { return IfThenElse(detail::UseInt(v), int_f - neg1, v); } +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + return integer - + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f < v))); +} + // Toward -infinity, aka floor template HWY_API Vec128 Floor(const Vec128 v) { @@ -9299,7 +12027,7 @@ HWY_API Vec128 Floor(const Vec128 v) { const DFromV df; const RebindToSigned di; - const auto integer = ConvertTo(di, v); // round toward 0 + const auto integer = ConvertInRangeTo(di, v); // round toward 0 const auto int_f = ConvertTo(df, integer); // Truncating a negative non-integer ends up larger; if so, subtract 1. @@ -9308,6 +12036,19 @@ HWY_API Vec128 Floor(const Vec128 v) { return IfThenElse(detail::UseInt(v), int_f + neg1, v); } +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + return integer + + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f > v))); +} + #else // Toward nearest integer, ties to even @@ -9407,6 +12148,16 @@ HWY_API Mask128 IsNaN(const Vec128 v) { _mm_fpclass_ph_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; } +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + template HWY_API Mask128 IsInf(const Vec128 v) { return Mask128{_mm_fpclass_ph_mask( @@ -9443,8 +12194,40 @@ HWY_API Mask128 IsNaN(const Vec128 v) { #endif } +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_ps(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_pd(a.raw, b.raw)}; +#endif +} + #if HWY_TARGET <= HWY_AVX3 +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + template HWY_API Mask128 IsInf(const Vec128 v) { return Mask128{_mm_fpclass_ps_mask( @@ -9472,35 +12255,6 @@ HWY_API Mask128 IsFinite(const Vec128 v) { HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); } -#else - -template -HWY_API Mask128 IsInf(const Vec128 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); - // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); -} - -// Returns whether normal/subnormal/zero. -template -HWY_API Mask128 IsFinite(const Vec128 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToUnsigned du; - const RebindToSigned di; // cheaper than unsigned comparison - const VFromD vu = BitCast(du, v); - // Shift left to clear the sign bit, then right so we can compare with the - // max exponent (cannot compare with MaxExponentTimes2 directly because it is - // negative and non-negative floats would be greater). MSVC seems to generate - // incorrect code if we instead add vu + vu. 
- const VFromD exp = - BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); - return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); -} - #endif // HWY_TARGET <= HWY_AVX3 // ================================================== CRYPTO @@ -9586,10 +12340,9 @@ HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { 1, 1, 1, 1, 1, 1, 1, 1}; const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); #endif - - alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, - 1, 2, 4, 8, 16, 32, 64, 128}; - return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); } template @@ -9644,6 +12397,20 @@ HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { #endif } +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + +#if HWY_TARGET <= HWY_AVX3 + return MFromD::FromBits(mask_bits); +#else + return detail::LoadMaskBits128(d, mask_bits); +#endif +} + template struct CompressIsPartition { #if HWY_TARGET <= HWY_AVX3 @@ -10779,243 +13546,99 @@ HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { // ------------------------------ Reductions -namespace detail { +// Nothing fully native, generic_ops-inl defines SumOfLanes and ReduceSum. -// N=1: no-op -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { - return v; -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - return v; -} +// We provide specializations of u8x8 and u8x16, so exclude those. +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf, uint8_t>() || \ + (HWY_V_SIZE_D(D) != 8 && HWY_V_SIZE_D(D) != 16)>* = \ + nullptr -// N=2 -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v10) { - const DFromV d; - return Add(v10, Reverse2(d, v10)); +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); } -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v10) { - const DFromV d; - return Min(v10, Reverse2(d, v10)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v10) { - const DFromV d; - return Max(v10, Reverse2(d, v10)); -} - -// N=4 (only 16/32-bit, else >128-bit) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Add(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Add(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Min(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Min(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Max(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Max(v03_12_12_03, v12_03_03_12); +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const Repartition d64; + VFromD sums = SumsOf8(v); + sums = SumOfLanes(d64, sums); + return Broadcast<0>(BitCast(d, sums)); } -#undef HWY_X86_IF_NOT_MINPOS #if HWY_TARGET <= HWY_SSE4 -// 
Skip the T_SIZE = 2 overload in favor of the following two. -#define HWY_X86_IF_NOT_MINPOS(T) \ - hwy::EnableIf()>* = nullptr +// We provide specializations of u8x8, u8x16, and u16x8, so exclude those. +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf<(!hwy::IsSame, uint8_t>() || \ + ((HWY_V_SIZE_D(D) < 8) || (HWY_V_SIZE_D(D) > 16))) && \ + (!hwy::IsSame, uint16_t>() || \ + (HWY_V_SIZE_D(D) != 16))>* = nullptr -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { +template +HWY_API Vec128 MinOfLanes(D /* tag */, Vec128 v) { return Broadcast<0>(Vec128{_mm_minpos_epu16(v.raw)}); } -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - const DFromV d; +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { const Vec128 max = Set(d, LimitsMax()); - return max - MinOfLanes(max - v); -} -#else -#define HWY_X86_IF_NOT_MINPOS(T) hwy::EnableIf* = nullptr -#endif // HWY_TARGET <= HWY_SSE4 - -// N=8 (only 16-bit, else >128-bit) -template -HWY_INLINE Vec128 SumOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Add(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Add(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Add(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec128 MinOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Min(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Min(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Min(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Max(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Max(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Max(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} - -template -HWY_INLINE T ReduceSum(Vec128 v) { - return GetLane(SumOfLanes(v)); -} - -// u8, N=8, N=16: -HWY_INLINE uint8_t ReduceSum(Vec64 v) { - return static_cast(GetLane(SumsOf8(v)) & 0xFF); -} -HWY_INLINE Vec64 SumOfLanes(Vec64 v) { - const Full64 d; - return Set(d, ReduceSum(v)); -} -HWY_INLINE uint8_t ReduceSum(Vec128 v) { - uint64_t sums = ReduceSum(SumsOf8(v)); - return static_cast(sums & 0xFF); -} -HWY_INLINE Vec128 SumOfLanes(Vec128 v) { - const DFromV d; - return Set(d, ReduceSum(v)); -} -template -HWY_INLINE int8_t ReduceSum(const Vec128 v) { - const DFromV d; - const RebindToUnsigned du; - const auto is_neg = v < Zero(d); - - // Sum positive and negative lanes separately, then combine to get the result. 
- const auto positive = SumsOf8(BitCast(du, IfThenZeroElse(is_neg, v))); - const auto negative = SumsOf8(BitCast(du, IfThenElseZero(is_neg, Abs(v)))); - return static_cast(ReduceSum(positive - negative) & 0xFF); -} -template -HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { - const DFromV d; - return Set(d, ReduceSum(v)); + return max - MinOfLanes(d, max - v); } -#if HWY_TARGET <= HWY_SSE4 -HWY_INLINE Vec64 MinOfLanes(Vec64 v) { - const DFromV d; +template +HWY_API Vec64 MinOfLanes(D d, Vec64 v) { const Rebind d16; - return TruncateTo(d, MinOfLanes(PromoteTo(d16, v))); + return TruncateTo(d, MinOfLanes(d16, PromoteTo(d16, v))); } -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { - const Half> d; +template +HWY_API Vec128 MinOfLanes(D d, Vec128 v) { + const Half dh; Vec64 result = - Min(MinOfLanes(UpperHalf(d, v)), MinOfLanes(LowerHalf(d, v))); - return Combine(DFromV(), result, result); + Min(MinOfLanes(dh, UpperHalf(dh, v)), MinOfLanes(dh, LowerHalf(dh, v))); + return Combine(d, result, result); } -HWY_INLINE Vec64 MaxOfLanes(Vec64 v) { - const Vec64 m(Set(DFromV(), LimitsMax())); - return m - MinOfLanes(m - v); -} -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - const Vec128 m(Set(DFromV(), LimitsMax())); - return m - MinOfLanes(m - v); +template +HWY_API Vec64 MaxOfLanes(D d, Vec64 v) { + const Vec64 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); } -#elif HWY_TARGET >= HWY_SSSE3 -template -HWY_API Vec128 MaxOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Max(v, Reverse2(d, v)); - vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { + const Vec128 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); } -template -HWY_API Vec128 MinOfLanes(Vec128 v) { - const DFromV d; - const RepartitionToWide d16; - const RepartitionToWide d32; - Vec128 vm = Min(v, Reverse2(d, v)); - vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); - vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); - if (N > 8) { - const RepartitionToWide d64; - vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); - } - return vm; -} +#endif // HWY_TARGET <= HWY_SSE4 + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE #endif -// Implement min/max of i8 in terms of u8 by toggling the sign bit. 
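// The SSE4 reductions above lean on two identities: max(x) == M - min(M - x)
// for unsigned x with M the type maximum (so a min-only instruction such as
// _mm_minpos_epu16 can also provide a maximum), and the i8 variants being
// removed here reused unsigned min/max after toggling the sign bit. A scalar
// sketch of both identities, assuming non-empty input; helper names are ours.
#include <algorithm>
#include <cstdint>
#include <vector>

static uint16_t MaxViaMinOfComplement(const std::vector<uint16_t>& xs) {
  uint16_t mn = 0xFFFFu;
  for (uint16_t x : xs) {
    mn = std::min<uint16_t>(mn, static_cast<uint16_t>(0xFFFFu - x));  // complement
  }
  return static_cast<uint16_t>(0xFFFFu - mn);  // complement of min == max
}

static int8_t MinI8ViaU8(const std::vector<int8_t>& xs) {
  uint8_t mn = 0xFFu;
  for (int8_t x : xs) {
    // XOR-ing the sign bit maps int8 order onto uint8 order (-128 -> 0, 127 -> 255).
    mn = std::min<uint8_t>(mn, static_cast<uint8_t>(static_cast<uint8_t>(x) ^ 0x80u));
  }
  return static_cast<int8_t>(mn ^ 0x80u);  // map back to signed
}
// Examples: MaxViaMinOfComplement({3, 7, 5}) == 7, MinI8ViaU8({-3, 5}) == -3.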
-template -HWY_INLINE Vec128 MinOfLanes(Vec128 v) { - const DFromV d; - const RebindToUnsigned du; - const auto mask = SignBit(du); - const auto vu = Xor(BitCast(du, v), mask); - return BitCast(d, Xor(MinOfLanes(vu), mask)); -} -template -HWY_INLINE Vec128 MaxOfLanes(Vec128 v) { - const DFromV d; - const RebindToUnsigned du; - const auto mask = SignBit(du); - const auto vu = Xor(BitCast(du, v), mask); - return BitCast(d, Xor(MaxOfLanes(vu), mask)); -} +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_LE_V(V, 16), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; -} // namespace detail + int32_t i32_bit_shuf_result = static_cast( + static_cast(_mm_bitshuffle_epi64_mask(v.raw, idx.raw))); -template -HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { - return detail::SumOfLanes(v); -} -template -HWY_API TFromD ReduceSum(D /* tag */, VFromD v) { - return detail::ReduceSum(v); -} -template -HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { - return detail::MinOfLanes(v); -} -template -HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { - return detail::MaxOfLanes(v); + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); } +#endif // HWY_TARGET <= HWY_AVX3_DL // ------------------------------ Lt128 @@ -11168,6 +13791,8 @@ HWY_API V LeadingZeroCount(V v) { } // namespace hwy HWY_AFTER_NAMESPACE(); +#undef HWY_X86_IF_EMULATED_D + // Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - // the warning seems to be issued at the call site of intrinsics, i.e. our code. HWY_DIAGNOSTICS(pop) diff --git a/r/src/vendor/highway/hwy/ops/x86_256-inl.h b/r/src/vendor/highway/hwy/ops/x86_256-inl.h index 1591b11e..df09c052 100644 --- a/r/src/vendor/highway/hwy/ops/x86_256-inl.h +++ b/r/src/vendor/highway/hwy/ops/x86_256-inl.h @@ -101,6 +101,9 @@ class Vec256 { HWY_INLINE Vec256& operator-=(const Vec256 other) { return *this = (*this - other); } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } HWY_INLINE Vec256& operator&=(const Vec256 other) { return *this = (*this & other); } @@ -191,6 +194,25 @@ HWY_INLINE __m256i BitCastToInteger(__m256d v) { return _mm256_castpd_si256(v); } +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m256i BitCastToInteger(__m256bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m256bh to a __m256i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m256bh vector + // to a __m256i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m256bh to a __m256i + return reinterpret_cast<__m256i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m256bh to a __m256i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m256i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + template HWY_INLINE Vec256 BitCastToByte(Vec256 v) { return Vec256{BitCastToInteger(v.raw)}; @@ -359,6 +381,85 @@ HWY_API VFromD ResizeBitCast(D d, FromV v) { ResizeBitCast(Full128(), v).raw)}); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, 
TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm256_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15), static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), static_cast(t4), + static_cast(t5), static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), static_cast(t10), + static_cast(t11), static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm256_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm256_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm256_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm256_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm256_setr_epi64x(static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm256_setr_pd(t0, t1, t0, t1)}; +} + // ================================================== LOGICAL // ------------------------------ And @@ -367,7 +468,8 @@ template HWY_API Vec256 And(Vec256 a, Vec256 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm256_and_si256(a.raw, b.raw)}); + return BitCast(d, VFromD{_mm256_and_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec256 And(Vec256 a, Vec256 b) { @@ -384,8 +486,8 @@ template HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast( - d, VFromD{_mm256_andnot_si256(not_mask.raw, mask.raw)}); + return BitCast(d, VFromD{_mm256_andnot_si256( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); } HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; @@ -400,7 +502,8 @@ template HWY_API Vec256 Or(Vec256 a, Vec256 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm256_or_si256(a.raw, b.raw)}); + return BitCast(d, VFromD{_mm256_or_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec256 Or(Vec256 a, Vec256 b) { @@ -416,7 +519,8 @@ template HWY_API Vec256 Xor(Vec256 a, Vec256 b) { const DFromV d; 
// for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm256_xor_si256(a.raw, b.raw)}); + return BitCast(d, VFromD{_mm256_xor_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec256 Xor(Vec256 a, Vec256 b) { @@ -431,7 +535,7 @@ template HWY_API Vec256 Not(const Vec256 v) { const DFromV d; using TU = MakeUnsigned; -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const __m256i vu = BitCast(RebindToUnsigned(), v).raw; return BitCast(d, Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); #else @@ -442,7 +546,7 @@ HWY_API Vec256 Not(const Vec256 v) { // ------------------------------ Xor3 template HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -457,7 +561,7 @@ HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { // ------------------------------ Or3 template HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -472,7 +576,7 @@ HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { // ------------------------------ OrAnd template HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -487,7 +591,7 @@ HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { // ------------------------------ IfVecThenElse template HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; @@ -589,7 +693,7 @@ HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, } // namespace detail -template +template HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } @@ -634,7 +738,7 @@ HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, } // namespace detail -template +template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } @@ -672,7 +776,7 @@ HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, } // namespace detail -template +template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } @@ -683,13 +787,6 @@ HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } -template -HWY_API Vec256 ZeroIfNegative(const Vec256 v) { - static_assert(IsSigned(), "Only for float"); - // AVX3 MaskFromVec only looks at the MSB - return IfThenZeroElse(MaskFromVec(v), v); -} - // ------------------------------ Mask logical namespace detail { @@ -879,6 +976,58 @@ HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, #endif } +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask32>(_knot_mask32(m.raw))}; +#else + return Mask256{static_cast<__mmask32>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask16>(_knot_mask16(m.raw))}; +#else + return 
Mask256{static_cast<__mmask16>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_knot_mask8(m.raw))}; +#else + return Mask256{static_cast<__mmask8>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Not(hwy::SizeTag<1> /*tag*/, const Mask256 m) { + // sizeof(T) == 1: simply return ~m as all 32 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<2> /*tag*/, const Mask256 m) { + // sizeof(T) == 2: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<4> /*tag*/, const Mask256 m) { + // sizeof(T) == 4: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<8> /*tag*/, const Mask256 m) { + // sizeof(T) == 8: need to zero out the upper 4 bits of ~m as only the lower + // 4 bits of m are valid + + // Return (~m) & 0x0F + return AndNot(hwy::SizeTag<8>(), m, Mask256::FromBits(uint64_t{0x0F})); +} + } // namespace detail template @@ -904,8 +1053,7 @@ HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { template HWY_API Mask256 Not(const Mask256 m) { // Flip only the valid bits. - constexpr size_t N = 32 / sizeof(T); - return Xor(m, Mask256::FromBits((1ull << N) - 1)); + return detail::Not(hwy::SizeTag(), m); } template @@ -913,6 +1061,53 @@ HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { return detail::ExclusiveNeither(hwy::SizeTag(), a, b); } +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask32 combined_mask = _mm512_kunpackw( + static_cast<__mmask32>(hi.raw), static_cast<__mmask32>(lo.raw)); +#else + const auto combined_mask = + ((static_cast(hi.raw) << 16) | (lo.raw & 0xFFFFu)); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask32(static_cast<__mmask32>(m.raw), 16); +#else + const auto shifted_mask = static_cast(m.raw) >> 16; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + #else // AVX2 // ------------------------------ Mask @@ -1072,7 +1267,11 @@ HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
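// A scalar sketch of the k-mask bookkeeping in the AVX3 branch above: a
// __mmask register carries one bit per lane and only the low N bits are
// valid, so Not must clear the invalid upper bits (e.g. 4 x 64-bit lanes in a
// 256-bit vector keep only the low 4 bits), and CombineMasks shifts the high
// half above the low half. Helper names are ours, operating on plain
// integers rather than __mmask types.
#include <cstdint>

static uint32_t NotMaskBits(uint32_t m, unsigned valid_lanes) {
  const uint32_t valid =
      (valid_lanes >= 32) ? 0xFFFFFFFFu : ((1u << valid_lanes) - 1u);
  return ~m & valid;  // flip only the valid bits
}

static uint32_t CombineMaskBits(uint16_t hi, uint16_t lo) {
  return (static_cast<uint32_t>(hi) << 16) | lo;  // hi lanes above lo lanes
}
// Example: NotMaskBits(0b0101, 4) == 0b1010.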
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 HWY_API Mask256 operator==(Vec256 a, Vec256 b) { @@ -1105,7 +1304,11 @@ HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { @@ -1146,7 +1349,11 @@ HWY_API Mask256 operator>(Vec256 a, Vec256 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 HWY_API Mask256 operator>(Vec256 a, Vec256 b) { @@ -1161,7 +1368,11 @@ HWY_API Mask256 operator>(Vec256 a, Vec256 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 @@ -1617,7 +1828,7 @@ HWY_INLINE VFromD Iota0(D /*d*/) { template HWY_API VFromD Iota(D d, const T2 first) { - return detail::Iota0(d) + Set(d, static_cast>(first)); + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); } // ------------------------------ FirstN (Iota, Lt) @@ -1732,6 +1943,15 @@ HWY_API Vec256 operator-(Vec256 a, Vec256 b) { return Vec256{_mm256_sub_pd(a.raw, b.raw)}; } +// ------------------------------ AddSub + +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_pd(a.raw, b.raw)}; +} + // ------------------------------ SumsOf8 HWY_API Vec256 SumsOf8(Vec256 v) { return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; @@ -1741,6 +1961,56 @@ HWY_API Vec256 SumsOf8AbsDiff(Vec256 a, Vec256 b) { return Vec256{_mm256_sad_epu8(a.raw, b.raw)}; } +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +HWY_INLINE Vec256 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec256 v) { + const DFromV d; + + // _mm256_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm256_maskz_dbsad_epu8 result. 
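// Scalar sketch of the AddSub semantics introduced above, assuming the usual
// addsub behavior: even-indexed lanes compute a - b, odd-indexed lanes a + b.
// AddSubScalar is an illustrative helper, not Highway API.
#include <cstddef>

void AddSubScalar(const float* a, const float* b, float* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = (i % 2 == 0) ? (a[i] - b[i]) : (a[i] + b[i]);
  }
}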
+ return Vec256{_mm256_maskz_dbsad_epu8( + static_cast<__mmask16>(0x5555), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec256 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +template +static Vec256 SumsOfAdjQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec256{_mm256_mpsadbw_epu8( + a.raw, b.raw, + (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)}; +} + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec256 SumsOfShuffledQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec256{ + _mm256_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -1761,7 +2031,7 @@ HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; } -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { const DFromV d; const auto sum = a + b; @@ -1783,7 +2053,7 @@ HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; return IfThenElse(overflow_mask, overflow_result, sum); } -#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN // ------------------------------ SaturatedSub @@ -1805,7 +2075,7 @@ HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; } -#if HWY_TARGET <= HWY_AVX3 +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { const DFromV d; const auto diff = a - b; @@ -1827,7 +2097,7 @@ HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; return IfThenElse(overflow_mask, overflow_result, diff); } -#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN // ------------------------------ Average @@ -1860,15 +2130,12 @@ HWY_API Vec256 Abs(const Vec256 v) { HWY_API Vec256 Abs(const Vec256 v) { return Vec256{_mm256_abs_epi32(v.raw)}; } -// i64 is implemented after BroadcastSignBit. 
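// Scalar sketch of the overflow test used by the AVX3 i64 SaturatedAdd above:
// the wrapped sum overflows exactly when a and b share a sign that differs
// from the sum's, and the saturated result then takes the sign of a.
// SaturatedAddI64 is an illustrative helper, not Highway API.
#include <cstdint>
#include <limits>

int64_t SaturatedAddI64(int64_t a, int64_t b) {
  // Two's complement wraparound via unsigned arithmetic, no UB.
  const int64_t sum = static_cast<int64_t>(static_cast<uint64_t>(a) +
                                           static_cast<uint64_t>(b));
  const bool overflow = ((a ^ sum) & (b ^ sum)) < 0;
  if (!overflow) return sum;
  return (a < 0) ? std::numeric_limits<int64_t>::min()
                 : std::numeric_limits<int64_t>::max();
}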
-template -HWY_API Vec256 Abs(const Vec256 v) { - const DFromV d; - const RebindToSigned di; - using TI = TFromD; - return v & BitCast(d, Set(di, static_cast(~SignMask()))); +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi64(v.raw)}; } +#endif // ------------------------------ Integer multiplication @@ -2016,14 +2283,29 @@ HWY_API Vec256 ShiftRight(Vec256 v) { // ------------------------------ RotateRight -template -HWY_API Vec256 RotateRight(const Vec256 v) { - constexpr size_t kSizeInBits = sizeof(T) * 8; - static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); +// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); if (kBits == 0) return v; - // AVX3 does not support 8/16-bit. - return Or(ShiftRight(v), - ShiftLeft(v)); +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif } template @@ -2048,6 +2330,38 @@ HWY_API Vec256 RotateRight(const Vec256 v) { #endif } +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi64(a.raw, b.raw)}; +} + +#endif + // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec256 BroadcastSignBit(const Vec256 v) { @@ -2086,16 +2400,6 @@ HWY_API Vec256 ShiftRight(const Vec256 v) { #endif } -HWY_API Vec256 Abs(const Vec256 v) { -#if HWY_TARGET <= HWY_AVX3 - return Vec256{_mm256_abs_epi64(v.raw)}; -#else - const DFromV d; - const auto zero = Zero(d); - return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); -#endif -} - // ------------------------------ IfNegativeThenElse (BroadcastSignBit) HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { @@ -2136,6 +2440,23 @@ HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { #endif } +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi8(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi16(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi32(v.raw, mask.raw)}; +} + // ------------------------------ ShiftLeftSame HWY_API Vec256 ShiftLeftSame(const Vec256 v, @@ -2330,6 +2651,25 @@ HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return 
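// Scalar sketch of the shift-based rotate fallback above for 8-bit lanes:
// (v >> k) | (v << (8 - k)), with k == 0 handled separately as in the vector
// code, where shifting a lane by its full width is not valid.
// RotateRight8 is an illustrative helper, not Highway API.
#include <cstdint>

template <int kBits>
constexpr uint8_t RotateRight8(uint8_t v) {
  static_assert(0 <= kBits && kBits < 8, "Invalid shift count");
  return (kBits == 0)
             ? v
             : static_cast<uint8_t>((v >> kBits) | (v << (8 - kBits)));
}
static_assert(RotateRight8<1>(0x01) == 0x80, "LSB wraps around to the MSB");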
Vec256{_mm256_mul_pd(a.raw, b.raw)}; } +#if HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulByFloorPow2(Vec256 a, + Vec256 b) { + return Vec256{_mm256_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + #if HWY_HAVE_FLOAT16 HWY_API Vec256 operator/(Vec256 a, Vec256 b) { return Vec256{_mm256_div_ph(a.raw, b.raw)}; @@ -2359,90 +2699,410 @@ HWY_API Vec256 ApproximateReciprocal(Vec256 v) { } #endif -// ------------------------------ Floating-point multiply-add variants +// ------------------------------ MaskedMinOr -#if HWY_HAVE_FLOAT16 +#if HWY_TARGET <= HWY_AVX3 -HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, - Vec256 add) { - return Vec256{_mm256_fmadd_ph(mul.raw, x.raw, add.raw)}; +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, - Vec256 add) { - return Vec256{_mm256_fnmadd_ph(mul.raw, x.raw, add.raw)}; +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, - Vec256 sub) { - return Vec256{_mm256_fmsub_ph(mul.raw, x.raw, sub.raw)}; +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, - Vec256 sub) { - return Vec256{_mm256_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; } +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} #endif // HWY_HAVE_FLOAT16 -HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, - Vec256 add) { -#ifdef HWY_DISABLE_BMI2_FMA - return mul * x + add; -#else - return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; -#endif +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, - Vec256 add) { -#ifdef HWY_DISABLE_BMI2_FMA - return 
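// Scalar sketch of MulByFloorPow2 above, assuming the scalef semantics of
// a * 2^floor(b) for finite, in-range inputs; special cases (NaN, infinities,
// overflow) follow the intrinsic and are ignored here. Illustrative only.
#include <cmath>

float MulByFloorPow2Scalar(float a, float b) {
  return std::ldexp(a, static_cast<int>(std::floor(b)));
}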
mul * x + add; -#else - return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, - Vec256 add) { -#ifdef HWY_DISABLE_BMI2_FMA - return add - mul * x; -#else - return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, - Vec256 add) { -#ifdef HWY_DISABLE_BMI2_FMA - return add - mul * x; -#else - return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, - Vec256 sub) { -#ifdef HWY_DISABLE_BMI2_FMA - return mul * x - sub; -#else - return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, - Vec256 sub) { -#ifdef HWY_DISABLE_BMI2_FMA - return mul * x - sub; -#else - return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; } -HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, - Vec256 sub) { -#ifdef HWY_DISABLE_BMI2_FMA - return Neg(mul * x) - sub; -#else - return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; -#endif +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, 
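// Scalar sketch of the Masked*Or pattern defined above and below: lanes
// selected by the mask receive op(a, b), all other lanes receive the fallback
// value `no`. Shown for Min; the sibling ops differ only in the body.
// MaskedMinOrScalar is an illustrative helper, not Highway API.
#include <algorithm>
#include <cstddef>
#include <cstdint>

void MaskedMinOrScalar(const uint8_t* no, const bool* m, const uint8_t* a,
                       const uint8_t* b, uint8_t* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = m[i] ? std::min(a[i], b[i]) : no[i];
  }
}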
+ Vec256 b) { + return Vec256{_mm256_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedMulOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedDivOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { 
+ return Vec256{_mm256_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif } HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, Vec256 sub) { @@ -2453,6 +3113,31 @@ HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, #endif } +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { + return Vec256{_mm256_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + // ------------------------------ Floating-point square root // Full precision square root @@ -2565,6 +3250,15 @@ HWY_API Mask256 IsNaN(Vec256 v) { v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; } +HWY_API Mask256 IsEitherNaN(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
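// Scalar sketch of the MulAddSub operation defined earlier in this hunk:
// even-indexed lanes yield mul * x - c, odd-indexed lanes mul * x + c, i.e.
// AddSub applied to the product, exactly as in the HWY_DISABLE_BMI2_FMA
// fallback. MulAddSubScalar is an illustrative helper, not Highway API.
#include <cstddef>

void MulAddSubScalar(const double* mul, const double* x, const double* c,
                     double* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = (i % 2 == 0) ? (mul[i] * x[i] - c[i]) : (mul[i] * x[i] + c[i]);
  }
}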
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + HWY_API Mask256 IsInf(Vec256 v) { return Mask256{_mm256_fpclass_ph_mask( v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; @@ -2597,6 +3291,22 @@ HWY_API Mask256 IsNaN(Vec256 v) { #endif } +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + #if HWY_TARGET <= HWY_AVX3 HWY_API Mask256 IsInf(Vec256 v) { @@ -2621,35 +3331,6 @@ HWY_API Mask256 IsFinite(Vec256 v) { HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); } -#else - -template -HWY_API Mask256 IsInf(const Vec256 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToSigned di; - const VFromD vi = BitCast(di, v); - // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. - return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); -} - -// Returns whether normal/subnormal/zero. -template -HWY_API Mask256 IsFinite(const Vec256 v) { - static_assert(IsFloat(), "Only for float"); - const DFromV d; - const RebindToUnsigned du; - const RebindToSigned di; // cheaper than unsigned comparison - const VFromD vu = BitCast(du, v); - // Shift left to clear the sign bit, then right so we can compare with the - // max exponent (cannot compare with MaxExponentTimes2 directly because it is - // negative and non-negative floats would be greater). MSVC seems to generate - // incorrect code if we instead add vu + vu. - const VFromD exp = - BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); - return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); -} - #endif // HWY_TARGET <= HWY_AVX3 // ================================================== MEMORY @@ -2662,16 +3343,13 @@ HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { _mm256_load_si256(reinterpret_cast(aligned))}; } // bfloat16_t is handled by x86_128-inl.h. -template -HWY_API Vec256 Load(D d, const float16_t* HWY_RESTRICT aligned) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API Vec256 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { return Vec256{_mm256_load_ph(aligned)}; -#else - const RebindToUnsigned du; - return BitCast(d, Load(du, reinterpret_cast(aligned))); -#endif // HWY_HAVE_FLOAT16 } +#endif template HWY_API Vec256 Load(D /* tag */, const float* HWY_RESTRICT aligned) { return Vec256{_mm256_load_ps(aligned)}; @@ -2686,16 +3364,12 @@ HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { return VFromD{_mm256_loadu_si256(reinterpret_cast(p))}; } // bfloat16_t is handled by x86_128-inl.h. 
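// Scalar sketch of IsEitherNaN above: the unordered comparison predicate
// (_CMP_UNORD_Q) is true exactly when at least one operand is NaN, and the
// quiet variant does not signal on quiet NaNs. Illustrative only.
#include <cmath>

bool IsEitherNaNScalar(double a, double b) {
  return std::isnan(a) || std::isnan(b);
}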
-template -HWY_API Vec256 LoadU(D d, const float16_t* HWY_RESTRICT p) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API Vec256 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { return Vec256{_mm256_loadu_ph(p)}; -#else - const RebindToUnsigned du; - return BitCast(d, LoadU(du, reinterpret_cast(p))); -#endif // HWY_HAVE_FLOAT16 } +#endif template HWY_API Vec256 LoadU(D /* tag */, const float* HWY_RESTRICT p) { return Vec256{_mm256_loadu_ps(p)}; @@ -2756,8 +3430,8 @@ template HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, const TFromD* HWY_RESTRICT p) { const RebindToUnsigned du; // for float16_t - return BitCast( - d, VFromD{_mm256_mask_loadu_epi16(v.raw, m.raw, p)}); + return BitCast(d, VFromD{ + _mm256_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); } template @@ -2831,22 +3505,24 @@ HWY_API Vec256 MaskedLoad(Mask256 m, D d, // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. template -HWY_API VFromD LoadDup128(D /* tag */, const TFromD* HWY_RESTRICT p) { +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; const Full128> d128; + const RebindToUnsigned du128; + const __m128i v128 = BitCast(du128, LoadU(d128, p)).raw; #if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the // upper half undefined) is fine because we're overwriting that anyway. // This workaround seems in turn to generate incorrect code in MSVC 2022 // (19.31), so use broadcastsi128 there. - const __m128i v128 = LoadU(d128, p).raw; - return VFromD{ - _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; + return BitCast(d, VFromD{_mm256_inserti128_si256( + _mm256_castsi128_si256(v128), v128, 1)}); #else // The preferred path. This is perhaps surprising, because vbroadcasti128 // with xmm input has 7 cycle latency on Intel, but Clang >= 7 is able to // pattern-match this to vbroadcastf128 with a memory operand as desired. 
- return VFromD{_mm256_broadcastsi128_si256(LoadU(d128, p).raw)}; + return BitCast(d, VFromD{_mm256_broadcastsi128_si256(v128)}); #endif } template @@ -2879,16 +3555,13 @@ template HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); } -template -HWY_API void Store(Vec256 v, D d, float16_t* HWY_RESTRICT aligned) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API void Store(Vec256 v, D /* tag */, + float16_t* HWY_RESTRICT aligned) { _mm256_store_ph(aligned, v.raw); -#else - const RebindToUnsigned du; - Store(BitCast(du, v), du, reinterpret_cast(aligned)); -#endif // HWY_HAVE_FLOAT16 } +#endif // HWY_HAVE_FLOAT16 template HWY_API void Store(Vec256 v, D /* tag */, float* HWY_RESTRICT aligned) { _mm256_store_ps(aligned, v.raw); @@ -2903,16 +3576,13 @@ template HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); } -template -HWY_API void StoreU(Vec256 v, D d, float16_t* HWY_RESTRICT p) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API void StoreU(Vec256 v, D /* tag */, + float16_t* HWY_RESTRICT p) { _mm256_storeu_ph(p, v.raw); -#else - const RebindToUnsigned du; - StoreU(BitCast(du, v), du, reinterpret_cast(p)); -#endif // HWY_HAVE_FLOAT16 } +#endif template HWY_API void StoreU(Vec256 v, D /* tag */, float* HWY_RESTRICT p) { _mm256_storeu_ps(p, v.raw); @@ -3140,118 +3810,124 @@ HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, // ------------------------------ Gather -template -HWY_INLINE VFromD GatherOffset(D /* tag */, - const TFromD* HWY_RESTRICT base, - Vec256 offset) { - return VFromD{_mm256_i32gather_epi32( - reinterpret_cast(base), offset.raw, 1)}; -} -template -HWY_INLINE VFromD GatherIndex(D /* tag */, - const TFromD* HWY_RESTRICT base, - Vec256 index) { - return VFromD{_mm256_i32gather_epi32( - reinterpret_cast(base), index.raw, 4)}; -} +namespace detail { -template -HWY_INLINE VFromD GatherOffset(D /* tag */, - const TFromD* HWY_RESTRICT base, - Vec256 offset) { - return VFromD{_mm256_i64gather_epi64( - reinterpret_cast(base), offset.raw, 1)}; +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; } -template -HWY_INLINE VFromD GatherIndex(D /* tag */, - const TFromD* HWY_RESTRICT base, - Vec256 index) { - return VFromD{_mm256_i64gather_epi64( - reinterpret_cast(base), index.raw, 8)}; + +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; } -template -HWY_API Vec256 GatherOffset(D /* tag */, const float* HWY_RESTRICT base, - Vec256 offset) { - return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; +template +HWY_API Vec256 NativeGather256(const float* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_ps(base, indices.raw, kScale)}; } -template -HWY_API Vec256 GatherIndex(D /* tag */, const float* HWY_RESTRICT base, - Vec256 index) { - return Vec256{_mm256_i32gather_ps(base, index.raw, 4)}; + +template +HWY_API Vec256 NativeGather256(const double* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_pd(base, indices.raw, kScale)}; } -template -HWY_API Vec256 GatherOffset(D /* tag */, - const double* HWY_RESTRICT base, - Vec256 offset) { - return Vec256{_mm256_i64gather_pd(base, offset.raw, 1)}; + +} // namespace 
detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather256<1>(base, offsets); } -template -HWY_API Vec256 GatherIndex(D /* tag */, const double* HWY_RESTRICT base, - Vec256 index) { - return Vec256{_mm256_i64gather_pd(base, index.raw, 8)}; + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather256)>(base, indices); } -// ------------------------------ MaskedGatherIndex +// ------------------------------ MaskedGatherIndexOr -template -HWY_INLINE VFromD MaskedGatherIndex(MFromD m, D d, - const TFromD* HWY_RESTRICT base, - Vec256 index) { +namespace detail { + +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{ - _mm256_mmask_i32gather_epi32(Zero(d).raw, m.raw, index.raw, - reinterpret_cast(base), 4)}; + return Vec256{_mm256_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; #else - return VFromD{_mm256_mask_i32gather_epi32( - Zero(d).raw, reinterpret_cast(base), index.raw, m.raw, - 4)}; + return Vec256{_mm256_mask_i32gather_epi32( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; #endif } -template -HWY_INLINE VFromD MaskedGatherIndex(MFromD m, D d, - const TFromD* HWY_RESTRICT base, - Vec256 index) { +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { #if HWY_TARGET <= HWY_AVX3 - return VFromD{_mm256_mmask_i64gather_epi64( - Zero(d).raw, m.raw, index.raw, - reinterpret_cast(base), 8)}; + return Vec256{_mm256_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; #else // For reasons unknown, _mm256_mask_i64gather_epi64 returns all-zeros. 
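// Scalar sketch of the kScale parameter threaded through the gathers above:
// GatherIndex addresses whole elements (scale sizeof(T)), GatherOffset
// addresses raw bytes (scale 1). Both helpers below are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T>
void GatherIndexScalar(const T* base, const int64_t* index, T* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = base[index[i]];  // byte address = base + index[i] * sizeof(T)
  }
}

template <typename T>
void GatherOffsetScalar(const T* base, const int64_t* offset, T* out,
                        size_t n) {
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(base);
  for (size_t i = 0; i < n; ++i) {
    std::memcpy(&out[i], bytes + offset[i], sizeof(T));  // offsets are bytes
  }
}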
- const RebindToFloat df; - return BitCast(d, Vec256{_mm256_mask_i64gather_pd( - Zero(df).raw, reinterpret_cast(base), - index.raw, RebindMask(df, m).raw, 8)}); + const Full256 d; + const Full256 dd; + return BitCast(d, + Vec256{_mm256_mask_i64gather_pd( + BitCast(dd, no).raw, reinterpret_cast(base), + indices.raw, RebindMask(dd, m).raw, kScale)}); #endif } -template -HWY_API Vec256 MaskedGatherIndex(MFromD m, D d, - const float* HWY_RESTRICT base, - Vec256 index) { +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const float* HWY_RESTRICT base, + Vec256 indices) { #if HWY_TARGET <= HWY_AVX3 return Vec256{ - _mm256_mmask_i32gather_ps(Zero(d).raw, m.raw, index.raw, base, 4)}; + _mm256_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; #else return Vec256{ - _mm256_mask_i32gather_ps(Zero(d).raw, base, index.raw, m.raw, 4)}; + _mm256_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; #endif } -template -HWY_API Vec256 MaskedGatherIndex(MFromD m, D d, - const double* HWY_RESTRICT base, - Vec256 index) { +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const double* HWY_RESTRICT base, + Vec256 indices) { #if HWY_TARGET <= HWY_AVX3 return Vec256{ - _mm256_mmask_i64gather_pd(Zero(d).raw, m.raw, index.raw, base, 8)}; + _mm256_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; #else return Vec256{ - _mm256_mask_i64gather_pd(Zero(d).raw, base, index.raw, m.raw, 8)}; + _mm256_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; #endif } +} // namespace detail + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr256)>(no, m, base, + indices); +} + HWY_DIAGNOSTICS(pop) // ================================================== SWIZZLE @@ -3294,7 +3970,7 @@ HWY_API Vec128 LowerHalf(Vec256 v) { template HWY_API VFromD UpperHalf(D d, VFromD> v) { const RebindToUnsigned du; // for float16_t - const Twice dut; + const Twice dut; return BitCast(d, VFromD{ _mm256_extracti128_si256(BitCast(dut, v).raw, 1)}); } @@ -3375,22 +4051,16 @@ template HWY_API VFromD ZeroExtendVector(D /* tag */, VFromD> lo) { #if HWY_HAVE_ZEXT return VFromD{_mm256_zextsi128_si256(lo.raw)}; +#elif HWY_COMPILER_MSVC + // Workaround: _mm256_inserti128_si256 does not actually zero the hi part. 
+ return VFromD{_mm256_set_m128i(_mm_setzero_si128(), lo.raw)}; #else return VFromD{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; #endif } -template -HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { - (void)d; -#if HWY_HAVE_ZEXT - return VFromD{_mm256_zextsi128_si256(lo.raw)}; -#else - return VFromD{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; -#endif // HWY_HAVE_ZEXT -} +#if HWY_HAVE_FLOAT16 template HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { -#if HWY_HAVE_FLOAT16 #if HWY_HAVE_ZEXT (void)d; return Vec256{_mm256_zextph128_ph256(lo.raw)}; @@ -3398,15 +4068,8 @@ HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { const RebindToUnsigned du; return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); #endif // HWY_HAVE_ZEXT -#else - (void)d; -#if HWY_HAVE_ZEXT - return VFromD{_mm256_zextsi128_si256(lo.raw)}; -#else - return VFromD{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; -#endif // HWY_HAVE_ZEXT -#endif // HWY_HAVE_FLOAT16 } +#endif // HWY_HAVE_FLOAT16 template HWY_API Vec256 ZeroExtendVector(D /* tag */, Vec128 lo) { #if HWY_HAVE_ZEXT @@ -3443,8 +4106,11 @@ HWY_INLINE VFromD ZeroExtendResizeBitCast( template HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { - const auto lo256 = ZeroExtendVector(d, lo); - return VFromD{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; + const RebindToUnsigned du; // for float16_t + const Half dh_u; + const auto lo256 = ZeroExtendVector(du, BitCast(dh_u, lo)); + return BitCast(d, VFromD{_mm256_inserti128_si256( + lo256.raw, BitCast(dh_u, hi).raw, 1)}); } template HWY_API Vec256 Combine(D d, Vec128 hi, Vec128 lo) { @@ -3547,8 +4213,12 @@ HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, template HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, Vec256 v) { - const Half> dh; - return Vec256{_mm256_broadcastw_epi16(LowerHalf(dh, v).raw)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Half dh; + const RebindToUnsigned dh_u; + return BitCast(d, VFromD{_mm256_broadcastw_epi16( + BitCast(dh_u, LowerHalf(dh, v)).raw)}); } template @@ -3983,7 +4653,10 @@ HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, template HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { - return Vec256{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + BitCast(du, v).raw, _MM_SHUFFLE(1, 0, 3, 2))}); } HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { @@ -4022,9 +4695,9 @@ HWY_API VFromD Reverse(D d, const VFromD v) { _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100}; - const auto rev128 = TableLookupBytes(v, LoadDup128(di, kShuffle)); + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + const auto rev128 = TableLookupBytes(v, shuffle); return VFromD{ _mm256_permute4x64_epi64(rev128.raw, _MM_SHUFFLE(1, 0, 3, 2))}; #endif @@ -4053,9 +4726,9 @@ HWY_API VFromD Reverse(D d, const VFromD v) { template HWY_API VFromD Reverse4(D d, const VFromD v) { const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 
0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); } // 32 bit Reverse4 defined in x86_128. @@ -4071,9 +4744,9 @@ HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { template HWY_API VFromD Reverse8(D d, const VFromD v) { const RebindToSigned di; - alignas(16) static constexpr int16_t kShuffle[8] = { - 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100}; - return BitCast(d, TableLookupBytes(v, LoadDup128(di, kShuffle))); + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); } template @@ -4162,8 +4835,12 @@ HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t const Half d2; - return VFromD{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; + const RebindToUnsigned du2; // for float16_t + return BitCast( + d, VFromD{_mm256_inserti128_si256( + BitCast(du, lo).raw, BitCast(du2, LowerHalf(d2, hi)).raw, 1)}); } template HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, @@ -4180,8 +4857,10 @@ HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template -HWY_API VFromD ConcatLowerUpper(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x21)}); } template HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, @@ -4196,8 +4875,10 @@ HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template -HWY_API VFromD ConcatUpperLower(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x0F)}); } template HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, @@ -4212,8 +4893,10 @@ HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template -HWY_API VFromD ConcatUpperUpper(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x31)}; +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x31)}); } template HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, @@ -4274,7 +4957,8 @@ HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); - return VFromD{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } @@ -4380,7 +5064,8 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { const Vec256 uH = And(BitCast(dw, hi), mask); const Vec256 uL = And(BitCast(dw, lo), mask); const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); - return VFromD{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 
2, 0))}; + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } @@ -4402,53 +5087,173 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { #endif } -template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + const VFromD v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return VFromD{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +template +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256{ + _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + (void)d; + const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ InterleaveWholeLower + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr 
uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} +#endif + +// ------------------------------ InterleaveWholeUpper + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; -#if HWY_TARGET <= HWY_AVX3 - alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; - return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; -#else - const VFromD v2020{ - _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; - return BitCast(d, Vec256{_mm256_permute4x64_epi64( - BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + alignas(32) static constexpr uint16_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} -#endif +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; } -template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; -#if HWY_TARGET <= HWY_AVX3 - alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; - return BitCast( - d, Vec256{_mm256_permutex2var_epi64( - BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); -#else - const RebindToFloat df; - const Vec256 v20{ - _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; - return VFromD{ - _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} -#endif +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; } template -HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { -#if HWY_TARGET <= HWY_AVX3 +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; - alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; - return Vec256{ - _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; -#else - (void)d; - const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; - return Vec256{ - _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; -#endif + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD 
InterleaveWholeUpper(D d, VFromD a, VFromD b) { + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); } +#endif // ------------------------------ DupEven (InterleaveLower) @@ -4490,9 +5295,10 @@ template HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { const DFromV d; const Full256 d8; - alignas(32) static constexpr uint8_t mask[16] = { - 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; - return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); + const VFromD mask = + Dup128VecFromValues(d8, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, + 0, 0xFF, 0, 0xFF, 0); + return IfThenElse(MaskFromVec(BitCast(d, mask)), b, a); } template @@ -4505,7 +5311,8 @@ HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { #if HWY_HAVE_FLOAT16 HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { - return Vec256{_mm256_mask_blend_ph(a.raw, b.raw, 0x55)}; + return Vec256{ + _mm256_mask_blend_ph(static_cast<__mmask16>(0x5555), a.raw, b.raw)}; } #endif // HWY_HAVE_FLOAT16 @@ -4527,11 +5334,80 @@ HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; } +// -------------------------- InterleaveEven + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0xAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0xAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD b2_b0_a2_a0{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveEven is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// -------------------------- InterleaveOdd + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x55), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x55), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD b3_b1_a3_a3{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b3_b1_a3_a3.raw, b3_b1_a3_a3.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveOdd is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + // ------------------------------ OddEvenBlocks template Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { - return Vec256{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, odd).raw, BitCast(du, even).raw, 0xFu)}); } HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { @@ -4554,7 
+5430,10 @@ HWY_API VFromD ReverseBlocks(D /*d*/, VFromD v) { // Both full template HWY_API Vec256 TableLookupBytes(Vec256 bytes, Vec256 from) { - return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; + const DFromV d; + return BitCast(d, Vec256{_mm256_shuffle_epi8( + BitCast(Full256(), bytes).raw, + BitCast(Full256(), from).raw)}); } // Partial index vector @@ -5114,14 +5993,15 @@ HWY_API Vec256 Shl(hwy::UnsignedTag tag, Vec256 v, const DFromV d; #if HWY_TARGET <= HWY_AVX3_DL (void)tag; - // kMask[i] = 0xFF >> i - alignas(16) static constexpr uint8_t kMasks[16] = { - 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0x00}; + // masks[i] = 0xFF >> i + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); // kShl[i] = 1 << i - alignas(16) static constexpr uint8_t kShl[16] = {1, 2, 4, 8, 0x10, - 0x20, 0x40, 0x80, 0x00}; - v = And(v, TableLookupBytes(LoadDup128(d, kMasks), bits)); - const VFromD mul = TableLookupBytes(LoadDup128(d, kShl), bits); + const VFromD shl = Dup128VecFromValues( + d, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 0, 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); return VFromD{_mm256_gf2p8mul_epi8(v.raw, mul.raw)}; #else const Repartition dw; @@ -5271,63 +6151,20 @@ HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { #endif } -HWY_INLINE Vec256 MulEven(const Vec256 a, - const Vec256 b) { - const Full256 du64; - const RepartitionToNarrow du32; - const auto maskL = Set(du64, 0xFFFFFFFFULL); - const auto a32 = BitCast(du32, a); - const auto b32 = BitCast(du32, b); - // Inputs for MulEven: we only need the lower 32 bits - const auto aH = Shuffle2301(a32); - const auto bH = Shuffle2301(b32); - - // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need - // the even (lower 64 bits of every 128-bit block) results. See - // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat - const auto aLbL = MulEven(a32, b32); - const auto w3 = aLbL & maskL; - - const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); - const auto w2 = t2 & maskL; - const auto w1 = ShiftRight<32>(t2); +// ------------------------------ WidenMulPairwiseAdd - const auto t = MulEven(a32, bH) + w2; - const auto k = ShiftRight<32>(t); +#if HWY_NATIVE_DOT_BF16 - const auto mulH = MulEven(aH, bH) + w1 + k; - const auto mulL = ShiftLeft<32>(t) + w3; - return InterleaveLower(mulL, mulH); +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm256_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; } -HWY_INLINE Vec256 MulOdd(const Vec256 a, - const Vec256 b) { - const Full256 du64; - const RepartitionToNarrow du32; - const auto maskL = Set(du64, 0xFFFFFFFFULL); - const auto a32 = BitCast(du32, a); - const auto b32 = BitCast(du32, b); - // Inputs for MulEven: we only need bits [95:64] (= upper half of input) - const auto aH = Shuffle2301(a32); - const auto bH = Shuffle2301(b32); +#endif // HWY_NATIVE_DOT_BF16 - // Same as above, but we're using the odd results (upper 64 bits per block). 
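// Scalar sketch of the Knuth double-word multiplication that the removed
// MulEven/MulOdd fallback above implements: a 64x64 -> 128-bit product built
// from four 32x32 -> 64-bit partial products. Mul64To128 is illustrative only.
#include <cstdint>
#include <utility>

std::pair<uint64_t, uint64_t> Mul64To128(uint64_t a, uint64_t b) {
  const uint64_t aL = a & 0xFFFFFFFFu, aH = a >> 32;
  const uint64_t bL = b & 0xFFFFFFFFu, bH = b >> 32;
  const uint64_t aLbL = aL * bL;
  const uint64_t t2 = aH * bL + (aLbL >> 32);  // cannot exceed 64 bits
  const uint64_t t = aL * bH + (t2 & 0xFFFFFFFFu);
  const uint64_t hi = aH * bH + (t2 >> 32) + (t >> 32);
  const uint64_t lo = (t << 32) | (aLbL & 0xFFFFFFFFu);
  return {hi, lo};  // {upper, lower} halves of the 128-bit product
}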
- const auto aLbL = MulEven(a32, b32); - const auto w3 = aLbL & maskL; - - const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); - const auto w2 = t2 & maskL; - const auto w1 = ShiftRight<32>(t2); - - const auto t = MulEven(a32, bH) + w2; - const auto k = ShiftRight<32>(t); - - const auto mulH = MulEven(aH, bH) + w1 + k; - const auto mulL = ShiftLeft<32>(t) + w3; - return InterleaveUpper(du64, mulL, mulH); -} - -// ------------------------------ WidenMulPairwiseAdd template HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec256 a, Vec256 b) { @@ -5343,7 +6180,31 @@ HWY_API VFromD SatWidenMulPairwiseAdd( return VFromD{_mm256_maddubs_epi16(a.raw, b.raw)}; } +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm256_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + // ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm256_dpbf16_ps(sum0.raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + template HWY_API VFromD ReorderWidenMulAccumulate(D d, Vec256 a, Vec256 b, @@ -5461,22 +6322,91 @@ HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { #if HWY_TARGET <= HWY_AVX3 template -HWY_API VFromD PromoteTo(D di64, VFromD> v) { - const Rebind df32; - const RebindToFloat df64; - const RebindToSigned di32; +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif - return detail::FixConversionOverflow( - di64, BitCast(df64, PromoteTo(di64, BitCast(di32, v))), - VFromD{_mm256_cvttps_epi64(v.raw)}); + __m256i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { - return VFromD{ - _mm256_maskz_cvttps_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } #endif // HWY_TARGET <= HWY_AVX3 +// ------------------------------ PromoteEvenTo/PromoteOddTo +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v))); +} + +} // namespace detail +#endif + // ------------------------------ Demotions (full -> part w/ narrow lanes) template @@ -5565,32 +6495,17 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { template HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm256_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm256_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; return VFromD{_mm256_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; } @@ -5617,32 +6532,54 @@ HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") template HWY_API VFromD DemoteTo(D df16, Vec256 v) { - (void)df16; - return VFromD{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); } HWY_DIAGNOSTICS(pop) #endif // HWY_DISABLE_F16C +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec256 v) { + return VFromD{_mm256_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C template -HWY_API VFromD DemoteTo(D dbf16, Vec256 v) { - // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. 
- const Rebind di32; - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); - return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec256 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtneps_pbh(v.raw))}; +#endif } template -HWY_API VFromD ReorderDemote2To(D dbf16, Vec256 a, Vec256 b) { - // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. - const RebindToUnsigned du16; - const Repartition du32; - const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec256 a, + Vec256 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtne2ps_pbh intrinsic returns a __m256bh vector that needs to + // be bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtne2ps_pbh(b.raw, a.raw))}; +#endif } +#endif // HWY_AVX3_HAVE_F32_TO_BF16C template HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, @@ -5733,9 +6670,9 @@ HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, _MM_SHUFFLE(2, 0, 2, 0))}); } -template -HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, - Vec256 b) { +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { const Half dnh; const Repartition dn_f; @@ -5767,37 +6704,64 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { } template -HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { - const Full256 d64; - const auto clamped = detail::ClampF64ToI32Max(d64, v); - return VFromD{_mm256_cvttpd_epi32(clamped.raw)}; +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epi32(v.raw)}; +#endif } -template -HWY_API VFromD DemoteTo(D du32, Vec256 v) { #if HWY_TARGET <= HWY_AVX3 - (void)du32; - return VFromD{ - _mm256_maskz_cvttpd_epu32(_knot_mask8(MaskFromVec(v).raw), v.raw)}; -#else // AVX2 - const Rebind df64; - const RebindToUnsigned du64; - - // Clamp v[i] to a value between 0 and 4294967295 - const auto clamped = Min(ZeroIfNegative(v), Set(df64, 4294967295.0)); 
+template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif - const auto k2_31 = Set(df64, 2147483648.0); - const auto clamped_is_ge_k2_31 = (clamped >= k2_31); - const auto clamped_lo31_f64 = - clamped - IfThenElseZero(clamped_is_ge_k2_31, k2_31); - const VFromD clamped_lo31_u32{_mm256_cvttpd_epi32(clamped_lo31_f64.raw)}; - const auto clamped_u32_msb = ShiftLeft<31>( - TruncateTo(du32, BitCast(du64, VecFromMask(df64, clamped_is_ge_k2_31)))); - return Or(clamped_lo31_u32, clamped_u32_msb); + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epu32(v.raw)}; #endif } -#if HWY_TARGET <= HWY_AVX3 template HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { return VFromD{_mm256_cvtepi64_ps(v.raw)}; @@ -5963,61 +6927,382 @@ HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { #if HWY_HAVE_FLOAT16 template -HWY_API VFromD ConvertTo(D d, Vec256 v) { - return detail::FixConversionOverflow(d, v, - VFromD{_mm256_cvttph_epi16(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epu16 with GCC if any + // values of 
v[i] are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epu16(v.raw)}; +#endif } #endif // HWY_HAVE_FLOAT16 template -HWY_API VFromD ConvertTo(D d, Vec256 v) { - return detail::FixConversionOverflow(d, v, - VFromD{_mm256_cvttps_epi32(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttps_epi32(v.raw)}; +#endif } #if HWY_TARGET <= HWY_AVX3 template -HWY_API VFromD ConvertTo(D di, Vec256 v) { - return detail::FixConversionOverflow(di, v, - VFromD{_mm256_cvttpd_epi64(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); 
+ const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } template -HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm256_maskz_cvttps_epu32(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } template -HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm256_maskz_cvttpd_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttpd_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } -#else // AVX2 -template -HWY_API VFromD ConvertTo(DU32 du32, VFromD> v) { - const RebindToSigned di32; - const RebindToFloat df32; - - const auto non_neg_v = ZeroIfNegative(v); - const auto exp_diff = Set(di32, int32_t{158}) - - BitCast(di32, 
ShiftRight<23>(BitCast(du32, non_neg_v))); - const auto scale_down_f32_val_mask = - BitCast(du32, VecFromMask(di32, Eq(exp_diff, Zero(di32)))); - - const auto v_scaled = BitCast( - df32, BitCast(du32, non_neg_v) + ShiftLeft<23>(scale_down_f32_val_mask)); - const VFromD f32_to_u32_result{ - _mm256_cvttps_epi32(v_scaled.raw)}; - - return Or( - BitCast(du32, BroadcastSignBit(exp_diff)), - f32_to_u32_result + And(f32_to_u32_result, scale_down_f32_val_mask)); +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtps_epi32 if any values of + // v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtps_epi32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi16(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtph_epi16(v.raw)}; +#endif +} +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if 
(detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi64x(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL } #endif // HWY_TARGET <= HWY_AVX3 -HWY_API Vec256 NearestInt(const Vec256 v) { - const Full256 di; - return detail::FixConversionOverflow( - di, v, Vec256{_mm256_cvtps_epi32(v.raw)}); +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi32(v.raw)}; +#endif } #ifndef HWY_DISABLE_F16C @@ -6035,6 +7320,15 @@ HWY_API VFromD PromoteTo(D df32, Vec128 v) { #endif // HWY_DISABLE_F16C +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec64 v) { + return VFromD{_mm256_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + template HWY_API VFromD PromoteTo(D df32, Vec128 v) { const Rebind du16; @@ -6120,14 +7414,14 @@ template HWY_API Vec256 AESKeyGenAssist(Vec256 v) { const Full256 d; #if HWY_TARGET <= HWY_AVX3_DL - alignas(16) static constexpr uint8_t kRconXorMask[16] = { - 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; - alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { - 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); const Repartition du32; const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); - const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask)); - return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); #else const Half d2; return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), @@ -6387,9 +7681,9 @@ HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { 0x0303030303030303ull}; const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); - alignas(32) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, - 1, 2, 4, 8, 16, 32, 64, 128}; - return RebindMask(d, 
TestBit(rep8, LoadDup128(du, kBit))); + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); } template @@ -6923,6 +8217,16 @@ HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, #endif // HWY_TARGET <= HWY_AVX3 +// ------------------------------ Dup128MaskFromMaskBits + +// Generic for all vector lengths >= 32 bytes +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + const auto mh = Dup128MaskFromMaskBits(dh, mask_bits); + return CombineMasks(d, mh, mh); +} + // ------------------------------ Expand // Always define Expand/LoadExpand because generic_ops only does so for Vec128. @@ -7396,116 +8700,26 @@ HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { } #endif // HWY_TARGET <= HWY_AVX3 -// ------------------------------ Reductions - -namespace detail { - -// These functions start with each lane per 128-bit block being reduced with the -// corresponding lane in the other block, so we use the same logic as x86_128 -// but running on both blocks at the same time. There are two (64-bit) to eight -// (16-bit) lanes per block. -template -HWY_INLINE Vec256 SumOfLanes(Vec256 v10) { - const DFromV d; - return Add(v10, Reverse2(d, v10)); -} -template -HWY_INLINE Vec256 MinOfLanes(Vec256 v10) { - const DFromV d; - return Min(v10, Reverse2(d, v10)); -} -template -HWY_INLINE Vec256 MaxOfLanes(Vec256 v10) { - const DFromV d; - return Max(v10, Reverse2(d, v10)); -} - -template -HWY_INLINE Vec256 SumOfLanes(Vec256 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Add(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Add(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec256 MinOfLanes(Vec256 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Min(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Min(v03_12_12_03, v12_03_03_12); -} -template -HWY_INLINE Vec256 MaxOfLanes(Vec256 v3210) { - using V = decltype(v3210); - const DFromV d; - const V v0123 = Reverse4(d, v3210); - const V v03_12_12_03 = Max(v3210, v0123); - const V v12_03_03_12 = Reverse2(d, v03_12_12_03); - return Max(v03_12_12_03, v12_03_03_12); -} +// ------------------------------ Reductions in generic_ops -template -HWY_INLINE Vec256 SumOfLanes(Vec256 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Add(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Add(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Add(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec256 MinOfLanes(Vec256 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. - const V v34_25_16_07 = Min(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Min(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Min(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} -template -HWY_INLINE Vec256 MaxOfLanes(Vec256 v76543210) { - using V = decltype(v76543210); - const DFromV d; - // The upper half is reversed from the lower half; omit for brevity. 
- const V v34_25_16_07 = Max(v76543210, Reverse8(d, v76543210)); - const V v0347_1625_1625_0347 = Max(v34_25_16_07, Reverse4(d, v34_25_16_07)); - return Max(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); -} +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, 32)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; -} // namespace detail + int32_t i32_bit_shuf_result = + static_cast(_mm256_bitshuffle_epi64_mask(v.raw, idx.raw)); -// Supported for >8-bit types. Returns the broadcasted result. -template -HWY_API VFromD SumOfLanes(D /*d*/, VFromD vHL) { - const VFromD vLH = SwapAdjacentBlocks(vHL); - return detail::SumOfLanes(Add(vLH, vHL)); -} -template -HWY_API TFromD ReduceSum(D d, VFromD v) { - return GetLane(SumOfLanes(d, v)); -} -#if HWY_HAVE_FLOAT16 -template -HWY_API float16_t ReduceSum(D, VFromD v) { - return _mm256_reduce_add_ph(v.raw); -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API VFromD MinOfLanes(D /*d*/, VFromD vHL) { - const VFromD vLH = SwapAdjacentBlocks(vHL); - return detail::MinOfLanes(Min(vLH, vHL)); -} -template -HWY_API VFromD MaxOfLanes(D /*d*/, VFromD vHL) { - const VFromD vLH = SwapAdjacentBlocks(vHL); - return detail::MaxOfLanes(Max(vLH, vHL)); + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); } +#endif // HWY_TARGET <= HWY_AVX3_DL -// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex +// ------------------------------ LeadingZeroCount #if HWY_TARGET <= HWY_AVX3 template ), HWY_IF_V_SIZE_V(V, 32)> diff --git a/r/src/vendor/highway/hwy/ops/x86_512-inl.h b/r/src/vendor/highway/hwy/ops/x86_512-inl.h index b5e948e1..c906b2e3 100644 --- a/r/src/vendor/highway/hwy/ops/x86_512-inl.h +++ b/r/src/vendor/highway/hwy/ops/x86_512-inl.h @@ -152,6 +152,9 @@ class Vec512 { HWY_INLINE Vec512& operator-=(const Vec512 other) { return *this = (*this - other); } + HWY_INLINE Vec512& operator%=(const Vec512 other) { + return *this = (*this % other); + } HWY_INLINE Vec512& operator&=(const Vec512 other) { return *this = (*this & other); } @@ -190,6 +193,25 @@ HWY_INLINE __m512i BitCastToInteger(__m512d v) { return _mm512_castpd_si512(v); } +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m512i BitCastToInteger(__m512bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m512bh to a __m512i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m512bh vector + // to a __m512i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m512bh to a __m512i + return reinterpret_cast<__m512i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m512bh to a __m512i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m512i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + template HWY_INLINE Vec512 BitCastToByte(Vec512 v) { return Vec512{BitCastToInteger(v.raw)}; @@ -373,6 +395,132 @@ HWY_API VFromD ResizeBitCast(D d, FromV v) { BitCast(Full256(), v).raw)}); } +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, 
TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. + return BroadcastBlock<0>(ResizeBitCast( + d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, t4, t5, t6, + t7, t8, t9, t10, t11, t12, t13, t14, t15))); +#else + (void)d; + // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic + // available + return VFromD{_mm512_set_epi8( + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0), static_cast(t15), static_cast(t14), + static_cast(t13), static_cast(t12), static_cast(t11), + static_cast(t10), static_cast(t9), static_cast(t8), + static_cast(t7), static_cast(t6), static_cast(t5), + static_cast(t4), static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), static_cast(t15), + static_cast(t14), static_cast(t13), static_cast(t12), + static_cast(t11), static_cast(t10), static_cast(t9), + static_cast(t8), static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), static_cast(t3), + static_cast(t2), static_cast(t1), static_cast(t0), + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0))}; +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. 
+ return BroadcastBlock<0>( + ResizeBitCast(d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, + t4, t5, t6, t7))); +#else + (void)d; + // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic + // available + return VFromD{ + _mm512_set_epi16(static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0))}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5, + t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm512_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2, + t3, t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm512_setr_epi64(static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)}; +} + // ----------------------------- Iota namespace detail { @@ -480,7 +628,7 @@ HWY_INLINE VFromD Iota0(D /*d*/) { template HWY_API VFromD Iota(D d, const T2 first) { - return detail::Iota0(d) + Set(d, static_cast>(first)); + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); } // ================================================== LOGICAL @@ -502,7 +650,8 @@ template HWY_API Vec512 And(const Vec512 a, const Vec512 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm512_and_si512(a.raw, b.raw)}); + return BitCast(d, VFromD{_mm512_and_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec512 And(const Vec512 a, const Vec512 b) { @@ -519,8 +668,8 @@ template HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast( - d, VFromD{_mm512_andnot_si512(not_mask.raw, mask.raw)}); + return BitCast(d, VFromD{_mm512_andnot_si512( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); } HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { @@ -537,7 +686,8 @@ template HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm512_or_si512(a.raw, b.raw)}); + 
return BitCast(d, VFromD{_mm512_or_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { @@ -553,7 +703,8 @@ template HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { const DFromV d; // for float16_t const RebindToUnsigned du; - return BitCast(d, VFromD{_mm512_xor_si512(a.raw, b.raw)}); + return BitCast(d, VFromD{_mm512_xor_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); } HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { @@ -566,45 +717,61 @@ HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { // ------------------------------ Xor3 template HWY_API Vec512 Xor3(Vec512 x1, Vec512 x2, Vec512 x3) { +#if !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif } // ------------------------------ Or3 template HWY_API Vec512 Or3(Vec512 o1, Vec512 o2, Vec512 o3) { +#if !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif } // ------------------------------ OrAnd template HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { +#if !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif } // ------------------------------ IfVecThenElse template HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { +#if !HWY_IS_MSAN const DFromV d; const RebindToUnsigned du; using VU = VFromD; return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif } // ------------------------------ Operator overloads (internal-only if float) @@ -752,7 +919,7 @@ HWY_API MFromD FirstN(D d, size_t n) { m.raw = static_cast(_bzhi_u64(all, n)); return m; #else - return detail::FirstN(n); + return detail::FirstN>(n); #endif // HWY_ARCH_X86_64 } @@ -790,7 +957,7 @@ HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, } // namespace detail -template +template HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); @@ -840,7 +1007,7 @@ HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, } // namespace detail -template +template HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } @@ -878,7 +1045,7 @@ HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, } // namespace detail -template +template HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } @@ -896,10 +1063,12 @@ HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { return IfThenElse(MaskFromVec(v), yes, no); } -template -HWY_API Vec512 ZeroIfNegative(const Vec512 v) { +template +HWY_API Vec512 IfNegativeThenNegOrUndefIfZero(Vec512 mask, Vec512 v) { // AVX3 MaskFromVec only looks at the MSB - return IfThenZeroElse(MaskFromVec(v), v); + const DFromV d; + return MaskedSubOr(v, MaskFromVec(mask), Zero(d), 
v); } // ================================================== ARITHMETIC @@ -1000,6 +1169,59 @@ HWY_API Vec512 SumsOf8AbsDiff(Vec512 a, Vec512 b) { return Vec512{_mm512_sad_epu8(a.raw, b.raw)}; } +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE Vec512 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec512 v) { + const DFromV d; + + // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result. + return Vec512{_mm512_maskz_dbsad_epu8( + static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)}; +} + +// I8->I32 SumsOf4 +// Generic for all vector lengths +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX2 di32; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -512 to each i32 lane of the result of the + // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account + // for the adjustment made above. + return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) + + Set(di32, int32_t{-512}); +} + +} // namespace detail + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec512 SumsOfShuffledQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec512{ + _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. @@ -1075,27 +1297,6 @@ HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi64(v.raw)}; } -// These aren't native instructions, they also involve AND with constant. -#if HWY_HAVE_FLOAT16 -HWY_API Vec512 Abs(const Vec512 v) { - return Vec512{_mm512_abs_ph(v.raw)}; -} -#endif // HWY_HAVE_FLOAT16 - -HWY_API Vec512 Abs(const Vec512 v) { - return Vec512{_mm512_abs_ps(v.raw)}; -} -HWY_API Vec512 Abs(const Vec512 v) { -// Workaround: _mm512_abs_pd expects __m512, so implement it ourselves. 
-#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 803 - const DFromV d; - const RebindToUnsigned du; - return And(v, BitCast(d, Set(du, 0x7FFFFFFFFFFFFFFFULL))); -#else - return Vec512{_mm512_abs_pd(v.raw)}; -#endif -} - // ------------------------------ ShiftLeft #if HWY_TARGET <= HWY_AVX3_DL @@ -1245,14 +1446,45 @@ HWY_API Vec512 ShiftRight(const Vec512 v) { // ------------------------------ RotateRight -template -HWY_API Vec512 RotateRight(const Vec512 v) { - constexpr size_t kSizeInBits = sizeof(T) * 8; - static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3_DL +// U8 RotateRight is generic for all vector lengths on AVX3_DL +template )> +HWY_API V RotateRight(V v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + + const Repartition> du64; + if (kBits == 0) return v; + + constexpr uint64_t kShrMatrix = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + constexpr int kShlBits = (-kBits) & 7; + constexpr uint64_t kShlMatrix = (0x0102040810204080ULL >> kShlBits) & + (0x0101010101010101ULL * (0xFF >> kShlBits)); + constexpr uint64_t kMatrix = kShrMatrix | kShlMatrix; + + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} +#else // HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); if (kBits == 0) return v; - // AVX3 does not support 8/16-bit. - return Or(ShiftRight(v), - ShiftLeft(v)); + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. 
+ return Or(ShiftRight(v), ShiftLeft(v)); +#endif } template @@ -1269,6 +1501,34 @@ HWY_API Vec512 RotateRight(const Vec512 v) { return Vec512{_mm512_ror_epi64(v.raw, kBits)}; } +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi64(a.raw, b.raw)}; +} + // ------------------------------ ShiftLeftSame // GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512 @@ -1617,6 +1877,21 @@ HWY_API Vec512 operator*(Vec512 a, Vec512 b) { return Vec512{_mm512_mul_pd(a.raw, b.raw)}; } +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulByFloorPow2(Vec512 a, + Vec512 b) { + return Vec512{_mm512_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_pd(a.raw, b.raw)}; +} + #if HWY_HAVE_FLOAT16 HWY_API Vec512 operator/(Vec512 a, Vec512 b) { return Vec512{_mm512_div_ph(a.raw, b.raw)}; @@ -1643,6 +1918,322 @@ HWY_API Vec512 ApproximateReciprocal(Vec512 v) { return Vec512{_mm512_rcp14_pd(v.raw)}; } +// ------------------------------ MaskedMinOr + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec512 
MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return 
Vec512{_mm512_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedMulOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedDivOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + // ------------------------------ Floating-point multiply-add variants #if HWY_HAVE_FLOAT16 @@ -1709,6 +2300,23 @@ HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; } +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return 
Vec512{_mm512_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +} + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +} + // ------------------------------ Floating-point square root // Full precision square root @@ -1873,7 +2481,11 @@ HWY_API Mask512 operator==(Vec512 a, Vec512 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 @@ -1907,7 +2519,11 @@ HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 @@ -1949,7 +2565,11 @@ HWY_API Mask512 operator>(Vec512 a, Vec512 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 @@ -1965,7 +2585,11 @@ HWY_API Mask512 operator>(Vec512 a, Vec512 b) { #if HWY_HAVE_FLOAT16 HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) } #endif // HWY_HAVE_FLOAT16 @@ -2328,11 +2952,63 @@ HWY_API Mask512 ExclusiveNeither(Mask512 a, Mask512 b) { return detail::ExclusiveNeither(hwy::SizeTag(), a, b); } +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask64 combined_mask = _mm512_kunpackd( + static_cast<__mmask64>(hi.raw), static_cast<__mmask64>(lo.raw)); +#else + const __mmask64 combined_mask = static_cast<__mmask64>( + ((static_cast(hi.raw) << 32) | (lo.raw & 0xFFFFFFFFULL))); +#endif + + return MFromD{combined_mask}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask64(static_cast<__mmask64>(m.raw), 32); +#else + const auto shifted_mask = static_cast(m.raw) >> 32; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec512 BroadcastSignBit(Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Repartition> du64; + return detail::GaloisAffine(v, Set(du64, 0x8080808080808080ull)); +#else const DFromV d; return VecFromMask(v < Zero(d)); +#endif } HWY_API Vec512 BroadcastSignBit(Vec512 v) { @@ -2344,7 +3020,7 @@ HWY_API Vec512 BroadcastSignBit(Vec512 v) { } HWY_API Vec512 BroadcastSignBit(Vec512 v) { - return Vec512{_mm512_srai_epi64(v.raw, 63)}; + return ShiftRight<63>(v); } // ------------------------------ Floating-point classification (Not) @@ -2356,6 +3032,15 @@ HWY_API Mask512 IsNaN(Vec512 v) { v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; } +HWY_API Mask512 IsEitherNaN(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + HWY_API Mask512 IsInf(Vec512 v) { return Mask512{_mm512_fpclass_ph_mask(v.raw, 0x18)}; } @@ -2379,6 +3064,14 @@ HWY_API Mask512 IsNaN(Vec512 v) { v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; } +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + HWY_API Mask512 IsInf(Vec512 v) { return Mask512{_mm512_fpclass_ps_mask( v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; @@ -2410,16 +3103,13 @@ HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { return VFromD{_mm512_load_si512(aligned)}; } // bfloat16_t is handled by x86_128-inl.h. 
-template -HWY_API Vec512 Load(D d, const float16_t* HWY_RESTRICT aligned) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API Vec512 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { return Vec512{_mm512_load_ph(aligned)}; -#else - const RebindToUnsigned du; - return BitCast(d, Load(du, reinterpret_cast(aligned))); -#endif // HWY_HAVE_FLOAT16 } +#endif // HWY_HAVE_FLOAT16 template HWY_API Vec512 Load(D /* tag */, const float* HWY_RESTRICT aligned) { return Vec512{_mm512_load_ps(aligned)}; @@ -2435,16 +3125,12 @@ HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { } // bfloat16_t is handled by x86_128-inl.h. -template -HWY_API Vec512 LoadU(D d, const float16_t* HWY_RESTRICT p) { #if HWY_HAVE_FLOAT16 - (void)d; +template +HWY_API Vec512 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { return Vec512{_mm512_loadu_ph(p)}; -#else - const RebindToUnsigned du; - return BitCast(d, LoadU(du, reinterpret_cast(p))); -#endif // HWY_HAVE_FLOAT16 } +#endif // HWY_HAVE_FLOAT16 template HWY_API Vec512 LoadU(D /* tag */, const float* HWY_RESTRICT p) { return Vec512{_mm512_loadu_ps(p)}; @@ -2506,8 +3192,9 @@ template HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, const TFromD* HWY_RESTRICT p) { const RebindToUnsigned du; // for float16_t - return VFromD{_mm512_mask_loadu_epi16( - BitCast(du, v).raw, m.raw, reinterpret_cast(p))}; + return BitCast( + d, VFromD{_mm512_mask_loadu_epi16( + BitCast(du, v).raw, m.raw, reinterpret_cast(p))}); } template @@ -2539,10 +3226,12 @@ HWY_API VFromD MaskedLoadOr(VFromD v, Mask512 m, D /* tag */, // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. template -HWY_API VFromD LoadDup128(D /* tag */, - const TFromD* const HWY_RESTRICT p) { +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const RebindToUnsigned du; const Full128> d128; - return VFromD{_mm512_broadcast_i32x4(LoadU(d128, p).raw)}; + const RebindToUnsigned du128; + return BitCast(d, VFromD{_mm512_broadcast_i32x4( + BitCast(du128, LoadU(d128, p)).raw)}); } template HWY_API VFromD LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { @@ -2563,15 +3252,13 @@ HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); } // bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 template HWY_API void Store(Vec512 v, D /* tag */, float16_t* HWY_RESTRICT aligned) { -#if HWY_HAVE_FLOAT16 _mm512_store_ph(aligned, v.raw); -#else - _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); -#endif } +#endif template HWY_API void Store(Vec512 v, D /* tag */, float* HWY_RESTRICT aligned) { _mm512_store_ps(aligned, v.raw); @@ -2586,15 +3273,13 @@ HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); } // bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 template HWY_API void StoreU(Vec512 v, D /* tag */, float16_t* HWY_RESTRICT p) { -#if HWY_HAVE_FLOAT16 _mm512_storeu_ph(p, v.raw); -#else - _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); -#endif // HWY_HAVE_FLOAT16 } +#endif // HWY_HAVE_FLOAT16 template HWY_API void StoreU(Vec512 v, D /* tag */, float* HWY_RESTRICT p) { @@ -2756,84 +3441,81 @@ HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, namespace detail { template -HWY_INLINE Vec512 NativeGather(const T* HWY_RESTRICT base, - Vec512 index) { - return Vec512{_mm512_i32gather_epi32(index.raw, base, kScale)}; +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_epi32(indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeGather(const T* HWY_RESTRICT base, - Vec512 index) { - return Vec512{_mm512_i64gather_epi64(index.raw, base, kScale)}; +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_epi64(indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeGather(const float* HWY_RESTRICT base, - Vec512 index) { - return Vec512{_mm512_i32gather_ps(index.raw, base, kScale)}; +HWY_INLINE Vec512 NativeGather512(const float* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_ps(indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeGather(const double* HWY_RESTRICT base, - Vec512 index) { - return Vec512{_mm512_i64gather_pd(index.raw, base, kScale)}; +HWY_INLINE Vec512 NativeGather512(const double* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_pd(indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeMaskedGather(Mask512 m, - const T* HWY_RESTRICT base, - Vec512 index) { - const Full512 d; +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { return Vec512{ - _mm512_mask_i32gather_epi32(Zero(d).raw, m.raw, index.raw, base, kScale)}; + _mm512_mask_i32gather_epi32(no.raw, m.raw, indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeMaskedGather(Mask512 m, - const T* HWY_RESTRICT base, - Vec512 index) { - const Full512 d; +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { return Vec512{ - _mm512_mask_i64gather_epi64(Zero(d).raw, m.raw, index.raw, base, kScale)}; + _mm512_mask_i64gather_epi64(no.raw, m.raw, indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeMaskedGather(Mask512 m, - const float* HWY_RESTRICT base, - Vec512 index) { - const Full512 d; +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, + Mask512 m, + const float* HWY_RESTRICT base, + Vec512 indices) { return Vec512{ - _mm512_mask_i32gather_ps(Zero(d).raw, m.raw, index.raw, base, kScale)}; + _mm512_mask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; } template -HWY_INLINE Vec512 NativeMaskedGather(Mask512 m, - const double* HWY_RESTRICT base, - Vec512 index) { - const Full512 d; +HWY_INLINE Vec512 NativeMaskedGatherOr512( + Vec512 no, Mask512 m, const double* HWY_RESTRICT base, + Vec512 indices) { return Vec512{ - _mm512_mask_i64gather_pd(Zero(d).raw, m.raw, index.raw, base, kScale)}; + _mm512_mask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; } } // namespace detail -template -HWY_API VFromD GatherOffset(D /* tag */, const TFromD* HWY_RESTRICT base, - Vec512 offset) { - static_assert(sizeof(TFromD) == sizeof(TI), "Must match for portability"); - return 
detail::NativeGather<1>(base, offset); +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather512<1>(base, offsets); } -template -HWY_API VFromD GatherIndex(D /* tag */, const TFromD* HWY_RESTRICT base, - Vec512 index) { - static_assert(sizeof(TFromD) == sizeof(TI), "Must match for portability"); - return detail::NativeGather)>(base, index); + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather512)>(base, indices); } -template -HWY_API VFromD MaskedGatherIndex(MFromD m, D /* tag */, - const TFromD* HWY_RESTRICT base, - Vec512 index) { - static_assert(sizeof(TFromD) == sizeof(TI), "Must match for portability"); - return detail::NativeMaskedGather)>(m, base, index); + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr512)>(no, m, base, + indices); } HWY_DIAGNOSTICS(pop) @@ -2878,7 +3560,7 @@ HWY_API Vec256 LowerHalf(Vec512 v) { template HWY_API VFromD UpperHalf(D d, VFromD> v) { const RebindToUnsigned du; // for float16_t - const Twice dut; + const Twice dut; return BitCast(d, VFromD{ _mm512_extracti32x8_epi32(BitCast(dut, v).raw, 1)}); } @@ -2920,7 +3602,11 @@ HWY_API Vec128 ExtractBlock(Vec512 v) { template 1)>* = nullptr> HWY_API Vec128 ExtractBlock(Vec512 v) { static_assert(kBlockIdx <= 3, "Invalid block index"); - return Vec128{_mm512_extracti32x4_epi32(v.raw, kBlockIdx)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(Full128(), + Vec128>{ + _mm512_extracti32x4_epi32(BitCast(du, v).raw, kBlockIdx)}); } template 1)>* = nullptr> @@ -2955,8 +3641,13 @@ HWY_INLINE Vec512 InsertBlock(hwy::SizeTag<0> /* blk_idx_tag */, Vec512 v, template HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, Vec512 v, Vec128 blk_to_insert) { - return Vec512{_mm512_inserti32x4(v.raw, blk_to_insert.raw, - static_cast(kBlockIdx & 3))}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Full128> du_blk_to_insert; + return BitCast( + d, VFromD{_mm512_inserti32x4( + BitCast(du, v).raw, BitCast(du_blk_to_insert, blk_to_insert).raw, + static_cast(kBlockIdx & 3))}); } template * = nullptr> @@ -2992,7 +3683,7 @@ HWY_API T GetLane(const Vec512 v) { // ------------------------------ ZeroExtendVector -template +template HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { #if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. 
(void)d; @@ -3042,11 +3733,13 @@ HWY_INLINE VFromD ZeroExtendResizeBitCast( DTo d_to, DFrom d_from, VFromD v) { const Repartition du8_from; const auto vu8 = BitCast(du8_from, v); + const RebindToUnsigned du_to; #if HWY_HAVE_ZEXT - (void)d_to; - return VFromD{_mm512_zextsi128_si512(vu8.raw)}; + return BitCast(d_to, + VFromD{_mm512_zextsi128_si512(vu8.raw)}); #else - return VFromD{_mm512_inserti32x4(Zero(d_to).raw, vu8.raw, 0)}; + return BitCast(d_to, VFromD{ + _mm512_inserti32x4(Zero(du_to).raw, vu8.raw, 0)}); #endif } @@ -3096,7 +3789,8 @@ HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { const RebindToUnsigned du; // for float16_t const Half duh; const __m512i lo512 = ZeroExtendVector(du, BitCast(duh, lo)).raw; - return VFromD{_mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)}; + return BitCast(d, VFromD{ + _mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)}); } template HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { @@ -3181,7 +3875,11 @@ HWY_API Vec512 Broadcast(const Vec512 v) { template HWY_API Vec512 BroadcastBlock(Vec512 v) { static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); - return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, 0x55 * kBlockIdx)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, 0x55 * kBlockIdx)}); } template @@ -3209,7 +3907,10 @@ HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, template HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, Vec512 v) { - return Vec512{_mm512_broadcastw_epi16(ResizeBitCast(Full128(), v).raw)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_broadcastw_epi16( + ResizeBitCast(Full128(), v).raw)}); } template @@ -3671,8 +4372,11 @@ HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template -HWY_API VFromD ConcatLowerLower(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BABA)}); } template HWY_API VFromD ConcatLowerLower(D /* tag */, VFromD hi, VFromD lo) { @@ -3686,8 +4390,11 @@ HWY_API Vec512 ConcatLowerLower(D /* tag */, Vec512 hi, // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template -HWY_API VFromD ConcatUpperUpper(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_DCDC)}); } template HWY_API VFromD ConcatUpperUpper(D /* tag */, VFromD hi, VFromD lo) { @@ -3701,8 +4408,11 @@ HWY_API Vec512 ConcatUpperUpper(D /* tag */, Vec512 hi, // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template -HWY_API VFromD ConcatLowerUpper(D /* tag */, VFromD hi, VFromD lo) { - return VFromD{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BADC)}); } template HWY_API VFromD ConcatLowerUpper(D /* tag */, VFromD 
hi, VFromD lo) { @@ -3716,11 +4426,13 @@ HWY_API Vec512 ConcatLowerUpper(D /* tag */, Vec512 hi, // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template -HWY_API VFromD ConcatUpperLower(D /* tag */, VFromD hi, VFromD lo) { +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks // are efficiently loaded from 32-bit regs. const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); - return VFromD{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_mask_blend_epi16( + mask, BitCast(du, hi).raw, BitCast(du, lo).raw)}); } template HWY_API VFromD ConcatUpperLower(D /* tag */, VFromD hi, VFromD lo) { @@ -3858,27 +4570,151 @@ HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { } template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +// ------------------------------ InterleaveWholeLower + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 0, 64, 1, 65, 2, 66, 3, 67, 4, 68, 5, 69, 6, 70, 7, 71, + 8, 72, 9, 73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79, + 16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87, + 24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, 
b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 32, 96, 33, 97, 34, 98, 35, 99, 36, 100, 37, 101, 38, 102, 39, 103, + 40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111, + 48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119, + 56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; alignas(64) static constexpr uint32_t kIdx[16] = { - 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; - return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; } template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; - alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; - return BitCast( - d, Vec512{_mm512_permutex2var_epi64( - BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; } template -HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { const RebindToUnsigned du; - alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; - return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; } // 
------------------------------ DupEven (InterleaveLower) @@ -3922,11 +4758,44 @@ HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); } +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + a.raw, static_cast<__mmask16>(0xAAAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(a.raw, static_cast<__mmask16>(0xAAAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + b.raw, static_cast<__mmask16>(0x5555), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(b.raw, static_cast<__mmask16>(0x5555), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} + // ------------------------------ OddEvenBlocks template HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { - return Vec512{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_mask_blend_epi64( + __mmask8{0x33u}, BitCast(du, odd).raw, BitCast(du, even).raw)}); } HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { @@ -3943,7 +4812,11 @@ HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { template HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { - return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_CDAB)}); } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { @@ -3957,8 +4830,11 @@ HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { // ------------------------------ ReverseBlocks template -HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { - return VFromD{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +HWY_API VFromD ReverseBlocks(D d, VFromD v) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_ABCD)}); } template HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { @@ -3974,7 +4850,10 @@ HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { // Both full template HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { - return Vec512{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; + const DFromV d; + return BitCast(d, Vec512{_mm512_shuffle_epi8( + BitCast(Full512(), bytes).raw, + BitCast(Full512(), indices).raw)}); } // Partial index vector @@ -4632,6 +5511,15 @@ HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { #endif // HWY_HAVE_FLOAT16 } +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec128 v) { + return VFromD{_mm512_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + template HWY_API VFromD PromoteTo(D df32, Vec256 v) { const Rebind du16; @@ -4655,19 +5543,76 @@ HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { } template -HWY_API VFromD PromoteTo(D di64, VFromD> v) { - const Rebind df32; - const RebindToFloat df64; - const RebindToSigned di32; +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // 
Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif - return detail::FixConversionOverflow( - di64, BitCast(df64, PromoteTo(di64, BitCast(di32, v))), - VFromD{_mm512_cvttps_epi64(v.raw)}); + __m512i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi64(v.raw)}; +#endif } template -HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { - return VFromD{ - _mm512_maskz_cvttps_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu64(v.raw)}; +#endif } // ------------------------------ Demotions (full -> part w/ narrow lanes) @@ -4709,8 +5654,7 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; - alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; - const auto idx32 = LoadDup128(du32, kLanes); + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; return LowerHalf(LowerHalf(fixed)); } @@ -4745,9 +5689,7 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; - alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, - 0, 4, 8, 12, 0, 4, 8, 12}; - const auto idx32 = LoadDup128(du32, kLanes); + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, 
i8.raw)}; return LowerHalf(LowerHalf(fixed)); } @@ -4779,32 +5721,17 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { template HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; return VFromD{_mm512_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; return VFromD{_mm512_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; } template HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { - const auto neg_mask = MaskFromVec(v); -#if HWY_COMPILER_HAS_MASK_INTRINSICS - const __mmask8 non_neg_mask = _knot_mask8(neg_mask.raw); -#else - const __mmask8 non_neg_mask = static_cast<__mmask8>(~neg_mask.raw); -#endif + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; return VFromD{_mm512_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; } @@ -4822,32 +5749,55 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { } template -HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { +HWY_API VFromD DemoteTo(D df16, Vec512 v) { // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") - return VFromD{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); HWY_DIAGNOSTICS(pop) } +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec512 v) { + return VFromD{_mm512_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C template -HWY_API VFromD DemoteTo(D dbf16, Vec512 v) { - // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. - const Rebind di32; - const Rebind du32; // for logical shift right - const Rebind du16; - const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); - return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec512 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtneps_pbh intrinsic returns a __m256bh vector that needs to be + // bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtneps_pbh(v.raw))}; +#endif } template -HWY_API VFromD ReorderDemote2To(D dbf16, Vec512 a, Vec512 b) { - // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. 
- const RebindToUnsigned du16; - const Repartition du32; - const Vec512 b_in_even = ShiftRight<16>(BitCast(du32, b)); - return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec512 a, + Vec512 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m512i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtne2ps_pbh intrinsic returns a __m512bh vector that needs to + // be bit casted to a __m512i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtne2ps_pbh(b.raw, a.raw))}; +#endif } +#endif // HWY_AVX3_HAVE_F32_TO_BF16C template HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, @@ -4935,16 +5885,77 @@ HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { } template -HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { - const Full512 d64; - const auto clamped = detail::ClampF64ToI32Max(d64, v); - return VFromD{_mm512_cvttpd_epi32(clamped.raw)}; +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi32(v.raw)}; +#endif } template -HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { - return VFromD{ - _mm512_maskz_cvttpd_epu32(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : 
HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu32(v.raw)}; +#endif } template @@ -4962,13 +5973,12 @@ HWY_API Vec128 U8FromU32(const Vec512 v) { const DFromV d32; // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the // lowest 4 bytes. - alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, - ~0u}; - const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); // Gather the lowest 4 bytes of 4 128-bit blocks. - alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; - const Vec512 bytes{ - _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; + const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; return LowerHalf(LowerHalf(bytes)); } @@ -4979,10 +5989,9 @@ HWY_API VFromD TruncateTo(D d, const Vec512 v) { #if HWY_TARGET <= HWY_AVX3_DL (void)d; const Full512 d8; - alignas(16) static constexpr uint8_t k8From64[16] = { - 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56}; - const Vec512 bytes{ - _mm512_permutexvar_epi8(LoadDup128(d8, k8From64).raw, v.raw)}; + const VFromD v8From64 = Dup128VecFromValues( + d8, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From64.raw, v.raw)}; return LowerHalf(LowerHalf(LowerHalf(bytes))); #else const Full512 d32; @@ -5018,21 +6027,19 @@ template HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { #if HWY_TARGET <= HWY_AVX3_DL const Full512 d8; - alignas(16) static constexpr uint8_t k8From32[16] = { - 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; - const Vec512 bytes{ - _mm512_permutexvar_epi8(LoadDup128(d8, k8From32).raw, v.raw)}; + const VFromD v8From32 = Dup128VecFromValues( + d8, 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From32.raw, v.raw)}; #else const Full512 d32; // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the // lowest 4 bytes. - alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, - ~0u}; - const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); // Gather the lowest 4 bytes of 4 128-bit blocks. 
- alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; - const Vec512 bytes{ - _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; + const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; #endif return LowerHalf(LowerHalf(bytes)); } @@ -5061,9 +6068,9 @@ HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; #else const Full512 d32; - alignas(16) static constexpr uint32_t k16From32[4] = { - 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u}; - const auto quads = TableLookupBytes(v, LoadDup128(d32, k16From32)); + const VFromD v16From32 = Dup128VecFromValues( + d32, 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u); + const auto quads = TableLookupBytes(v, v16From32); alignas(64) static constexpr uint32_t kIndex32[16] = { 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; const Vec512 bytes{ @@ -5108,36 +6115,491 @@ HWY_API VFromD ConvertTo(D /* tag*/, Vec512 v) { // Truncates (rounds toward zero). #if HWY_HAVE_FLOAT16 template -HWY_API VFromD ConvertTo(D d, Vec512 v) { - return detail::FixConversionOverflow(d, v, - VFromD{_mm512_cvttph_epi16(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ConvertScalarFromFloat(raw_v[31]), + detail::X86ConvertScalarFromFloat(raw_v[30]), + detail::X86ConvertScalarFromFloat(raw_v[29]), + detail::X86ConvertScalarFromFloat(raw_v[28]), + detail::X86ConvertScalarFromFloat(raw_v[27]), + detail::X86ConvertScalarFromFloat(raw_v[26]), + detail::X86ConvertScalarFromFloat(raw_v[25]), + detail::X86ConvertScalarFromFloat(raw_v[24]), + detail::X86ConvertScalarFromFloat(raw_v[23]), + detail::X86ConvertScalarFromFloat(raw_v[22]), + detail::X86ConvertScalarFromFloat(raw_v[21]), + detail::X86ConvertScalarFromFloat(raw_v[20]), + detail::X86ConvertScalarFromFloat(raw_v[19]), + detail::X86ConvertScalarFromFloat(raw_v[18]), + detail::X86ConvertScalarFromFloat(raw_v[17]), + detail::X86ConvertScalarFromFloat(raw_v[16]), + detail::X86ConvertScalarFromFloat(raw_v[15]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return 
VFromD{_mm512_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epu16 with GCC if any + // values of v[i] are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_set_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[31])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[30])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[29])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[28])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[27])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[26])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[25])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[24])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[23])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[22])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[21])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[20])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[19])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[18])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[17])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[16])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttph_epu16(v.raw)}; +#endif } #endif // HWY_HAVE_FLOAT16 template -HWY_API VFromD ConvertTo(D d, Vec512 v) { - return detail::FixConversionOverflow(d, v, - VFromD{_mm512_cvttps_epi32(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return 
VFromD{_mm512_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi32(v.raw)}; +#endif } template -HWY_API VFromD ConvertTo(D di, Vec512 v) { - return detail::FixConversionOverflow(di, v, - VFromD{_mm512_cvttpd_epi64(v.raw)}); +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi64(v.raw)}; +#endif } template -HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm512_maskz_cvttps_epu32(_knot_mask16(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + 
static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu32(v.raw)}; +#endif } template -HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { - return VFromD{ - _mm512_maskz_cvttpd_epu64(_knot_mask8(MaskFromVec(v).raw), v.raw)}; +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu64(v.raw)}; +#endif +} + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return 
VFromD{_mm512_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ScalarNearestInt(raw_v[31]), + detail::X86ScalarNearestInt(raw_v[30]), + detail::X86ScalarNearestInt(raw_v[29]), + detail::X86ScalarNearestInt(raw_v[28]), + detail::X86ScalarNearestInt(raw_v[27]), + detail::X86ScalarNearestInt(raw_v[26]), + detail::X86ScalarNearestInt(raw_v[25]), + detail::X86ScalarNearestInt(raw_v[24]), + detail::X86ScalarNearestInt(raw_v[23]), + detail::X86ScalarNearestInt(raw_v[22]), + detail::X86ScalarNearestInt(raw_v[21]), + detail::X86ScalarNearestInt(raw_v[20]), + detail::X86ScalarNearestInt(raw_v[19]), + detail::X86ScalarNearestInt(raw_v[18]), + detail::X86ScalarNearestInt(raw_v[17]), + detail::X86ScalarNearestInt(raw_v[16]), + detail::X86ScalarNearestInt(raw_v[15]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi64(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi64(v.raw)}; +#endif } -HWY_API Vec512 NearestInt(const Vec512 v) { - const Full512 di; - return detail::FixConversionOverflow( - di, v, Vec512{_mm512_cvtps_epi32(v.raw)}); +template +static 
HWY_INLINE VFromD DemoteToNearestIntInRange(DI /* tag */, + Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi32(v.raw)}; +#endif } // ================================================== CRYPTO @@ -5198,14 +6660,14 @@ template HWY_API Vec512 AESKeyGenAssist(Vec512 v) { const Full512 d; #if HWY_TARGET <= HWY_AVX3_DL - alignas(16) static constexpr uint8_t kRconXorMask[16] = { - 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; - alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { - 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); const Repartition du32; const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); - const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask)); - return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); #else const Half d2; return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), @@ -5253,6 +6715,29 @@ HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { // ================================================== MISC +// ------------------------------ SumsOfAdjQuadAbsDiff (Broadcast, +// SumsOfAdjShufQuadAbsDiff) + +template +static Vec512 SumsOfAdjQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + + const DFromV d; + const RepartitionToWideX2 du32; + + // While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the + // SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on + // AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast. 
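The comment above notes that AVX-512 has no 512-bit `_mm512_mpsadbw_epu8`, so `SumsOfAdjQuadAbsDiff` is emulated via `SumsOfShuffledQuadAbsDiff` plus a u32 `Broadcast`. For readers unfamiliar with the underlying MPSADBW semantics, here is a scalar reference for one 16-byte block; the helper name and the exact offset convention are illustrative assumptions, not library internals.

```cpp
#include <cstdint>
#include <cstdlib>

// For each of 8 output words: sum over k = 0..3 of
// |a[kAOffset*4 + j + k] - b[kBOffset*4 + k]|, i.e. quad absolute
// differences taken at adjacent starting offsets j of a.
template <int kAOffset, int kBOffset>
void SumsOfAdjQuadAbsDiffRef(const uint8_t a[16], const uint8_t b[16],
                             uint16_t out[8]) {
  static_assert(0 <= kAOffset && kAOffset <= 1, "kAOffset must be 0..1");
  static_assert(0 <= kBOffset && kBOffset <= 3, "kBOffset must be 0..3");
  for (int j = 0; j < 8; ++j) {
    uint16_t sum = 0;
    for (int k = 0; k < 4; ++k) {
      sum += static_cast<uint16_t>(
          std::abs(int{a[kAOffset * 4 + j + k]} - int{b[kBOffset * 4 + k]}));
    }
    out[j] = sum;
  }
}
```

The vector version that follows replicates the selected 32-bit quad of `b` across each block via `Broadcast` before handing both inputs to the shuffled-quad SAD.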
+ return SumsOfShuffledQuadAbsDiff( + a, BitCast(d, Broadcast(BitCast(du32, b)))); +} + +#if !HWY_IS_MSAN // ------------------------------ I32/I64 SaturatedAdd (MaskFromVec) HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { @@ -5300,6 +6785,7 @@ HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; return IfThenElse(overflow_mask, overflow_result, diff); } +#endif // !HWY_IS_MSAN // ------------------------------ Mask testing @@ -5446,6 +6932,15 @@ HWY_API intptr_t FindLastTrue(D d, MFromD mask) { // ------------------------------ Compress +#ifndef HWY_X86_SLOW_COMPRESS_STORE // allow override +// Slow on Zen4 and SPR, faster if we emulate via Compress(). +#if HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR +#define HWY_X86_SLOW_COMPRESS_STORE 1 +#else +#define HWY_X86_SLOW_COMPRESS_STORE 0 +#endif +#endif // HWY_X86_SLOW_COMPRESS_STORE + // Always implement 8-bit here even if we lack VBMI2 because we can do better // than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). #ifdef HWY_NATIVE_COMPRESS8 @@ -5485,8 +6980,8 @@ HWY_INLINE Vec512 NativeCompress(const Vec512 v, return Vec512{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; } -// Slow on Zen4, do not even define these to prevent accidental usage. -#if HWY_TARGET != HWY_AVX3_ZEN4 +// Do not even define these to prevent accidental usage. +#if !HWY_X86_SLOW_COMPRESS_STORE template HWY_INLINE void NativeCompressStore(Vec128 v, @@ -5518,7 +7013,7 @@ HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); } -#endif // HWY_TARGET != HWY_AVX3_ZEN4 +#endif // HWY_X86_SLOW_COMPRESS_STORE HWY_INLINE Vec512 NativeExpand(Vec512 v, Mask512 mask) { @@ -5559,8 +7054,8 @@ HWY_INLINE Vec512 NativeCompress(Vec512 v, } // We use table-based compress for 64-bit lanes, see CompressIsPartition. -// Slow on Zen4, do not even define these to prevent accidental usage. -#if HWY_TARGET != HWY_AVX3_ZEN4 +// Do not even define these to prevent accidental usage. 
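Because the new `HWY_X86_SLOW_COMPRESS_STORE` block above is wrapped in `#ifndef` ("allow override"), a build can force either path regardless of the detected target. A minimal sketch, under the assumption that the other x86 ops headers honor the same override:

```cpp
// Force the native compress-store instructions even on Zen4/SPR builds,
// e.g. to benchmark both paths. Must be defined before including Highway.
#define HWY_X86_SLOW_COMPRESS_STORE 0
#include "hwy/highway.h"
```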
+#if !HWY_X86_SLOW_COMPRESS_STORE template HWY_INLINE void NativeCompressStore(Vec128 v, @@ -5621,7 +7116,7 @@ HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); } -#endif // HWY_TARGET != HWY_AVX3_ZEN4 +#endif // HWY_X86_SLOW_COMPRESS_STORE HWY_INLINE Vec512 NativeExpand(Vec512 v, Mask512 mask) { @@ -6081,7 +7576,7 @@ HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { template HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { -#if HWY_TARGET == HWY_AVX3_ZEN4 +#if HWY_X86_SLOW_COMPRESS_STORE StoreU(Compress(v, mask), d, unaligned); #else const RebindToUnsigned du; @@ -6093,7 +7588,7 @@ HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, #else detail::EmuCompressStore(BitCast(du, v), mu, du, pu); #endif -#endif // HWY_TARGET != HWY_AVX3_ZEN4 +#endif // HWY_X86_SLOW_COMPRESS_STORE const size_t count = CountTrue(d, mask); detail::MaybeUnpoison(unaligned, count); return count; @@ -6103,7 +7598,7 @@ template HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { -#if HWY_TARGET == HWY_AVX3_ZEN4 +#if HWY_X86_SLOW_COMPRESS_STORE StoreU(Compress(v, mask), d, unaligned); #else const RebindToUnsigned du; @@ -6111,7 +7606,7 @@ HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, using TU = TFromD; TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); detail::NativeCompressStore(BitCast(du, v), mu, pu); -#endif // HWY_TARGET != HWY_AVX3_ZEN4 +#endif // HWY_X86_SLOW_COMPRESS_STORE const size_t count = CountTrue(d, mask); detail::MaybeUnpoison(unaligned, count); return count; @@ -6121,12 +7616,12 @@ HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, template HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { -#if HWY_TARGET == HWY_AVX3_ZEN4 +#if HWY_X86_SLOW_COMPRESS_STORE StoreU(Compress(v, mask), d, unaligned); #else (void)d; detail::NativeCompressStore(v, mask, unaligned); -#endif // HWY_TARGET != HWY_AVX3_ZEN4 +#endif // HWY_X86_SLOW_COMPRESS_STORE const size_t count = PopCount(uint64_t{mask.raw}); detail::MaybeUnpoison(unaligned, count); return count; @@ -6139,7 +7634,7 @@ HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, // Native CompressStore already does the blending at no extra cost (latency // 11, rthroughput 2 - same as compress plus store). if (HWY_TARGET == HWY_AVX3_DL || - (HWY_TARGET != HWY_AVX3_ZEN4 && sizeof(TFromD) > 2)) { + (!HWY_X86_SLOW_COMPRESS_STORE && sizeof(TFromD) > 2)) { return CompressStore(v, m, d, unaligned); } else { const size_t count = CountTrue(d, m); @@ -6165,7 +7660,10 @@ namespace detail { // Type-safe wrapper. 
template <_MM_PERM_ENUM kPerm, typename T> Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { - return Vec512{_mm512_shuffle_i64x2(lo.raw, hi.raw, kPerm)}; + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_shuffle_i64x2( + BitCast(du, lo).raw, BitCast(du, hi).raw, kPerm)}); } template <_MM_PERM_ENUM kPerm> Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { @@ -6345,7 +7843,7 @@ HWY_API Mask512 SetOnlyFirst(Mask512 mask) { static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; } -// ------------------------------ Shl (LoadDup128) +// ------------------------------ Shl (Dup128VecFromValues) HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; @@ -6356,13 +7854,15 @@ HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { const DFromV d; #if HWY_TARGET <= HWY_AVX3_DL // kMask[i] = 0xFF >> i - alignas(16) static constexpr uint8_t kMasks[16] = { - 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0x00}; + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); // kShl[i] = 1 << i - alignas(16) static constexpr uint8_t kShl[16] = {0x01, 0x02, 0x04, 0x08, - 0x10, 0x20, 0x40, 0x80}; - v = And(v, TableLookupBytes(LoadDup128(d, kMasks), bits)); - const VFromD mul = TableLookupBytes(LoadDup128(d, kShl), bits); + const VFromD shl = + Dup128VecFromValues(d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0, + 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); return VFromD{_mm512_gf2p8mul_epi8(v.raw, mul.raw)}; #else const Repartition dw; @@ -6457,65 +7957,18 @@ HWY_API Vec512 operator>>(const Vec512 v, return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; } -// ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) - -HWY_INLINE Vec512 MulEven(const Vec512 a, - const Vec512 b) { - const DFromV du64; - const RepartitionToNarrow du32; - const auto maskL = Set(du64, 0xFFFFFFFFULL); - const auto a32 = BitCast(du32, a); - const auto b32 = BitCast(du32, b); - // Inputs for MulEven: we only need the lower 32 bits - const auto aH = Shuffle2301(a32); - const auto bH = Shuffle2301(b32); - - // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need - // the even (lower 64 bits of every 128-bit block) results. See - // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat - const auto aLbL = MulEven(a32, b32); - const auto w3 = aLbL & maskL; - - const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); - const auto w2 = t2 & maskL; - const auto w1 = ShiftRight<32>(t2); - - const auto t = MulEven(a32, bH) + w2; - const auto k = ShiftRight<32>(t); - - const auto mulH = MulEven(aH, bH) + w1 + k; - const auto mulL = ShiftLeft<32>(t) + w3; - return InterleaveLower(mulL, mulH); -} - -HWY_INLINE Vec512 MulOdd(const Vec512 a, - const Vec512 b) { - const DFromV du64; - const RepartitionToNarrow du32; - const auto maskL = Set(du64, 0xFFFFFFFFULL); - const auto a32 = BitCast(du32, a); - const auto b32 = BitCast(du32, b); - // Inputs for MulEven: we only need bits [95:64] (= upper half of input) - const auto aH = Shuffle2301(a32); - const auto bH = Shuffle2301(b32); - - // Same as above, but we're using the odd results (upper 64 bits per block). 
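A few hunks above, the u8 `operator<<` on AVX3_DL masks each byte with `0xFF >> shift` before multiplying by `1 << shift` with `_mm512_gf2p8mul_epi8`. The GF(2^8) multiply can stand in for a per-byte variable shift because, after masking, the carry-less product never exceeds 8 bits, so no polynomial reduction occurs. A standalone scalar check of that identity (helper names here are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Carry-less (XOR) multiply, keeping only the low 8 bits of the product.
static uint8_t ClMul8(uint8_t a, uint8_t b) {
  uint8_t r = 0;
  for (int i = 0; i < 8; ++i) {
    if (b & (1u << i)) r ^= static_cast<uint8_t>(a << i);
  }
  return r;
}

int main() {
  for (unsigned x = 0; x < 256; ++x) {
    for (unsigned s = 0; s < 8; ++s) {
      const uint8_t masked =
          static_cast<uint8_t>(x) & static_cast<uint8_t>(0xFF >> s);
      // With the high s bits masked off, the product fits in 8 bits, so the
      // GF(2^8) multiply (no reduction triggered) equals a plain left shift.
      if (ClMul8(masked, static_cast<uint8_t>(1u << s)) !=
          static_cast<uint8_t>(x << s)) {
        std::printf("mismatch at x=%u s=%u\n", x, s);
        return 1;
      }
    }
  }
  std::printf("ok\n");
  return 0;
}
```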
- const auto aLbL = MulEven(a32, b32); - const auto w3 = aLbL & maskL; - - const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); - const auto w2 = t2 & maskL; - const auto w1 = ShiftRight<32>(t2); - - const auto t = MulEven(a32, bH) + w2; - const auto k = ShiftRight<32>(t); +// ------------------------------ WidenMulPairwiseAdd - const auto mulH = MulEven(aH, bH) + w1 + k; - const auto mulL = ShiftLeft<32>(t) + w3; - return InterleaveUpper(du64, mulL, mulH); +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm512_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; } +#endif // HWY_NATIVE_DOT_BF16 -// ------------------------------ WidenMulPairwiseAdd template HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec512 a, Vec512 b) { @@ -6523,7 +7976,6 @@ HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec512 a, } // ------------------------------ SatWidenMulPairwiseAdd - template HWY_API VFromD SatWidenMulPairwiseAdd( DI16 /* tag */, VFromD> a, @@ -6531,7 +7983,30 @@ HWY_API VFromD SatWidenMulPairwiseAdd( return VFromD{_mm512_maddubs_epi16(a.raw, b.raw)}; } +// ------------------------------ SatWidenMulPairwiseAccumulate +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + // ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm512_dpbf16_ps(sum0.raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + template HWY_API VFromD ReorderWidenMulAccumulate(D d, Vec512 a, Vec512 b, @@ -6570,161 +8045,47 @@ HWY_API VFromD SumOfMulQuadAccumulate( // ------------------------------ Reductions -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return _mm512_reduce_add_epi32(v.raw); -} -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return _mm512_reduce_add_epi64(v.raw); -} -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return static_cast(_mm512_reduce_add_epi32(v.raw)); -} -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return static_cast(_mm512_reduce_add_epi64(v.raw)); -} -#if HWY_HAVE_FLOAT16 -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return _mm512_reduce_add_ph(v.raw); -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return _mm512_reduce_add_ps(v.raw); -} -template -HWY_API TFromD ReduceSum(D, VFromD v) { - return _mm512_reduce_add_pd(v.raw); -} -template -HWY_API TFromD ReduceSum(D d, VFromD v) { - const RepartitionToWide d32; - const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto sum = ReduceSum(d32, even + odd); - return static_cast(sum); -} -template -HWY_API TFromD ReduceSum(D d, VFromD v) { - const RepartitionToWide d32; - // Sign-extend - const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto sum = ReduceSum(d32, even + odd); - return static_cast(sum); -} +namespace detail { -// Returns the sum in each lane. 
-template -HWY_API VFromD SumOfLanes(D d, VFromD v) { - return Set(d, ReduceSum(d, v)); +// Used by generic_ops-inl +template +HWY_INLINE VFromD ReduceAcrossBlocks(D d, Func f, VFromD v) { + v = f(v, SwapAdjacentBlocks(v)); + return f(v, ReverseBlocks(d, v)); } -// Returns the minimum in each lane. -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_epi32(v.raw)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_epi64(v.raw)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_epu32(v.raw)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_epu64(v.raw)); -} -#if HWY_HAVE_FLOAT16 -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_ph(v.raw)); -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_ps(v.raw)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_min_pd(v.raw)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - const RepartitionToWide d32; - const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto min = MinOfLanes(d32, Min(even, odd)); - // Also broadcast into odd lanes. - return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); -} -template -HWY_API VFromD MinOfLanes(D d, VFromD v) { - const RepartitionToWide d32; - // Sign-extend - const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto min = MinOfLanes(d32, Min(even, odd)); - // Also broadcast into odd lanes. - return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); -} +} // namespace detail -// Returns the maximum in each lane. -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_epi32(v.raw)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_epi64(v.raw)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_epu32(v.raw)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_epu64(v.raw)); -} -#if HWY_HAVE_FLOAT16 -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_ph(v.raw)); -} -#endif // HWY_HAVE_FLOAT16 -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_ps(v.raw)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - return Set(d, _mm512_reduce_max_pd(v.raw)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - const RepartitionToWide d32; - const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto min = MaxOfLanes(d32, Max(even, odd)); - // Also broadcast into odd lanes. - return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); -} -template -HWY_API VFromD MaxOfLanes(D d, VFromD v) { - const RepartitionToWide d32; - // Sign-extend - const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); - const auto odd = ShiftRight<16>(BitCast(d32, v)); - const auto min = MaxOfLanes(d32, Max(even, odd)); - // Also broadcast into odd lanes. 
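The per-type reduction boilerplate removed here is replaced by `detail::ReduceAcrossBlocks` (defined above), which folds the four 128-bit blocks of a 512-bit vector in two steps and lets generic_ops supply the per-lane combiner. A scalar sketch of that folding, with each array element standing in for one block (names are illustrative):

```cpp
#include <array>
#include <cstdio>

// Each element models one 128-bit block; f is the per-block combiner
// (e.g. lane-wise Add/Min/Max) supplied by the generic reduction code.
template <typename T, typename F>
std::array<T, 4> ReduceAcrossBlocksRef(std::array<T, 4> blocks, F f) {
  const std::array<T, 4> swapped{blocks[1], blocks[0], blocks[3], blocks[2]};
  for (int i = 0; i < 4; ++i) blocks[i] = f(blocks[i], swapped[i]);  // SwapAdjacentBlocks
  const std::array<T, 4> reversed{blocks[3], blocks[2], blocks[1], blocks[0]};
  for (int i = 0; i < 4; ++i) blocks[i] = f(blocks[i], reversed[i]);  // ReverseBlocks
  return blocks;  // every element now holds f over all four inputs
}

int main() {
  const auto r = ReduceAcrossBlocksRef<int>(std::array<int, 4>{1, 2, 3, 4},
                                            [](int a, int b) { return a + b; });
  std::printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]);  // 10 10 10 10
  return 0;
}
```

After the first step each pair of adjacent blocks holds the same partial result, so the block reversal in the second step pairs every partial with the other one, leaving the full reduction in all four blocks.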
- return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, 64)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + const __mmask64 mmask64_bit_shuf_result = + _mm512_bitshuffle_epi64_mask(v.raw, idx.raw); + +#if HWY_ARCH_X86_64 + const VFromD vu8_bit_shuf_result{ + _mm_cvtsi64_si128(static_cast(mmask64_bit_shuf_result))}; +#else + const int32_t i32_lo_bit_shuf_result = + static_cast(mmask64_bit_shuf_result); + const int32_t i32_hi_bit_shuf_result = + static_cast(_kshiftri_mask64(mmask64_bit_shuf_result, 32)); + + const VFromD vu8_bit_shuf_result = ResizeBitCast( + du8, InterleaveLower( + Vec128{_mm_cvtsi32_si128(i32_lo_bit_shuf_result)}, + Vec128{_mm_cvtsi32_si128(i32_hi_bit_shuf_result)})); +#endif + + return BitCast(d64, PromoteTo(du64, vu8_bit_shuf_result)); } +#endif // HWY_TARGET <= HWY_AVX3_DL // -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex diff --git a/r/src/vendor/highway/hwy/per_target.cc b/r/src/vendor/highway/hwy/per_target.cc index 63e69616..4f9de2e3 100644 --- a/r/src/vendor/highway/hwy/per_target.cc +++ b/r/src/vendor/highway/hwy/per_target.cc @@ -13,8 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Enable all targets so that calling Have* does not call into a null pointer. +#ifndef HWY_COMPILE_ALL_ATTAINABLE +#define HWY_COMPILE_ALL_ATTAINABLE +#endif #include "hwy/per_target.h" +#include +#include + #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/per_target.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep @@ -23,7 +30,9 @@ HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { +int64_t GetTarget() { return HWY_TARGET; } size_t GetVectorBytes() { return Lanes(ScalableTag()); } +bool GetHaveInteger64() { return HWY_HAVE_INTEGER64 != 0; } bool GetHaveFloat16() { return HWY_HAVE_FLOAT16 != 0; } bool GetHaveFloat64() { return HWY_HAVE_FLOAT64 != 0; } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -35,15 +44,25 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace hwy { namespace { +HWY_EXPORT(GetTarget); HWY_EXPORT(GetVectorBytes); +HWY_EXPORT(GetHaveInteger64); HWY_EXPORT(GetHaveFloat16); HWY_EXPORT(GetHaveFloat64); } // namespace +HWY_DLLEXPORT int64_t DispatchedTarget() { + return HWY_DYNAMIC_DISPATCH(GetTarget)(); +} + HWY_DLLEXPORT size_t VectorBytes() { return HWY_DYNAMIC_DISPATCH(GetVectorBytes)(); } +HWY_DLLEXPORT bool HaveInteger64() { + return HWY_DYNAMIC_DISPATCH(GetHaveInteger64)(); +} + HWY_DLLEXPORT bool HaveFloat16() { return HWY_DYNAMIC_DISPATCH(GetHaveFloat16)(); } diff --git a/r/src/vendor/highway/hwy/per_target.h b/r/src/vendor/highway/hwy/per_target.h index 52c316ec..7a86b0eb 100644 --- a/r/src/vendor/highway/hwy/per_target.h +++ b/r/src/vendor/highway/hwy/per_target.h @@ -17,6 +17,7 @@ #define HIGHWAY_HWY_PER_TARGET_H_ #include +#include #include "hwy/highway_export.h" @@ -25,6 +26,9 @@ namespace hwy { +// Returns the HWY_TARGET which HWY_DYNAMIC_DISPATCH selected. +HWY_DLLEXPORT int64_t DispatchedTarget(); + // Returns size in bytes of a vector, i.e. `Lanes(ScalableTag())`. // // Do not cache the result, which may change after calling DisableTargets, or @@ -35,7 +39,8 @@ namespace hwy { // unnecessarily. 
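per_target.h now exposes `DispatchedTarget()` and `HaveInteger64()` alongside the existing queries. A minimal usage sketch, assuming the public `hwy::TargetName` helper from targets.h for pretty-printing:

```cpp
#include <cstdio>

#include "hwy/per_target.h"
#include "hwy/targets.h"  // hwy::TargetName

int main() {
  std::printf("dispatched target: %s\n",
              hwy::TargetName(hwy::DispatchedTarget()));
  std::printf("vector bytes: %zu\n", hwy::VectorBytes());
  std::printf("int64 lanes: %d, f16: %d, f64: %d\n",
              static_cast<int>(hwy::HaveInteger64()),
              static_cast<int>(hwy::HaveFloat16()),
              static_cast<int>(hwy::HaveFloat64()));
  return 0;
}
```

As the header comment notes, none of these results should be cached, because DisableTargets can change them at runtime.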
HWY_DLLEXPORT size_t VectorBytes(); -// Returns whether 16/64-bit floats are a supported lane type. +// Returns whether 64-bit integers, 16/64-bit floats are a supported lane type. +HWY_DLLEXPORT bool HaveInteger64(); HWY_DLLEXPORT bool HaveFloat16(); HWY_DLLEXPORT bool HaveFloat64(); diff --git a/r/src/vendor/highway/hwy/print.cc b/r/src/vendor/highway/hwy/print.cc index ac9d3f87..b2c68b14 100644 --- a/r/src/vendor/highway/hwy/print.cc +++ b/r/src/vendor/highway/hwy/print.cc @@ -40,16 +40,24 @@ HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100) { HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, char* string100) { if (info.sizeof_t == 1) { - uint8_t byte; - CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1. - snprintf(string100, 100, "0x%02X", byte); // NOLINT + if (info.is_signed) { + int8_t byte; + CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1. + snprintf(string100, 100, "%d", byte); // NOLINT + } else { + uint8_t byte; + CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1. + snprintf(string100, 100, "0x%02X", byte); // NOLINT + } } else if (info.sizeof_t == 2) { if (info.is_bf16) { const double value = static_cast(F32FromBF16Mem(ptr)); - snprintf(string100, 100, "%.3f", value); // NOLINT + const char* fmt = hwy::ScalarAbs(value) < 1E-3 ? "%.3E" : "%.3f"; + snprintf(string100, 100, fmt, value); // NOLINT } else if (info.is_float) { const double value = static_cast(F32FromF16Mem(ptr)); - snprintf(string100, 100, "%.4f", value); // NOLINT + const char* fmt = hwy::ScalarAbs(value) < 1E-4 ? "%.4E" : "%.4f"; + snprintf(string100, 100, fmt, value); // NOLINT } else { uint16_t bits; CopyBytes<2>(ptr, &bits); @@ -59,7 +67,8 @@ HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, if (info.is_float) { float value; CopyBytes<4>(ptr, &value); - snprintf(string100, 100, "%.9f", static_cast(value)); // NOLINT + const char* fmt = hwy::ScalarAbs(value) < 1E-6 ? "%.9E" : "%.9f"; + snprintf(string100, 100, fmt, static_cast(value)); // NOLINT } else if (info.is_signed) { int32_t value; CopyBytes<4>(ptr, &value); @@ -69,12 +78,12 @@ HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, CopyBytes<4>(ptr, &value); snprintf(string100, 100, "%u", value); // NOLINT } - } else { - HWY_ASSERT(info.sizeof_t == 8); + } else if (info.sizeof_t == 8) { if (info.is_float) { double value; CopyBytes<8>(ptr, &value); - snprintf(string100, 100, "%.18f", value); // NOLINT + const char* fmt = hwy::ScalarAbs(value) < 1E-9 ? "%.18E" : "%.18f"; + snprintf(string100, 100, fmt, value); // NOLINT } else { const uint8_t* ptr8 = reinterpret_cast(ptr); uint32_t lo, hi; @@ -82,6 +91,17 @@ HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 0), &hi); snprintf(string100, 100, "0x%08x%08x", hi, lo); // NOLINT } + } else if (info.sizeof_t == 16) { + HWY_ASSERT(!info.is_float && !info.is_signed && !info.is_bf16); + const uint8_t* ptr8 = reinterpret_cast(ptr); + uint32_t words[4]; + CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 0 : 12), &words[0]); + CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 8), &words[1]); + CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 8 : 4), &words[2]); + CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 
12 : 0), &words[3]); + // NOLINTNEXTLINE + snprintf(string100, 100, "0x%08x%08x_%08x%08x", words[3], words[2], + words[1], words[0]); } } diff --git a/r/src/vendor/highway/hwy/profiler.h b/r/src/vendor/highway/hwy/profiler.h new file mode 100644 index 00000000..467ac0c4 --- /dev/null +++ b/r/src/vendor/highway/hwy/profiler.h @@ -0,0 +1,682 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PROFILER_H_ +#define HIGHWAY_HWY_PROFILER_H_ + +// High precision, low overhead time measurements. Returns exact call counts and +// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes). +// +// Uses RAII to capture begin/end timestamps, with user-specified zone names: +// { PROFILER_ZONE("name"); /*code*/ } or +// the name of the current function: +// void FuncToMeasure() { PROFILER_FUNC; /*code*/ }. +// +// After all threads have exited any zones, invoke PROFILER_PRINT_RESULTS() to +// print call counts and average durations [CPU cycles] to stdout, sorted in +// descending order of total duration. +// +// The binary MUST be built with --dynamic_mode=off because we rely on the data +// segments being nearby; if not, an assertion will likely fail. + +#include "hwy/base.h" + +// Configuration settings: + +// If zero, this file has no effect and no measurements will be recorded. +#ifndef PROFILER_ENABLED +#define PROFILER_ENABLED 0 +#endif + +// How many mebibytes to allocate (if PROFILER_ENABLED) per thread that +// enters at least one zone. Once this buffer is full, the thread will analyze +// and discard packets, thus temporarily adding some observer overhead. +// Each zone occupies 16 bytes. +#ifndef PROFILER_THREAD_STORAGE +#define PROFILER_THREAD_STORAGE 200ULL +#endif + +#if PROFILER_ENABLED || HWY_IDE + +#include +#include +#include +#include // strcmp + +#include // std::sort +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/cache_control.h" // FlushStream +// #include "hwy/contrib/sort/vqsort.h" +#include "hwy/robust_statistics.h" +#include "hwy/timer-inl.h" +#include "hwy/timer.h" + +#define PROFILER_PRINT_OVERHEAD 0 + +namespace hwy { + +// Upper bounds for fixed-size data structures (guarded via HWY_DASSERT): + +// How many threads can actually enter a zone (those that don't do not count). +// Memory use is about kMaxThreads * PROFILER_THREAD_STORAGE MiB. +// WARNING: a fiber library can spawn hundreds of threads. +static constexpr size_t kMaxThreads = 256; + +static constexpr size_t kMaxDepth = 64; // Maximum nesting of zones. + +static constexpr size_t kMaxZones = 256; // Total number of zones. + +#pragma pack(push, 1) + +// Represents zone entry/exit events. Stores a full-resolution timestamp plus +// an offset (representing zone name or identifying exit packets). POD. +class Packet { + public: + // If offsets do not fit, UpdateOrAdd will overrun our heap allocation + // (governed by kMaxZones). We have seen multi-megabyte offsets. 
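Before the implementation details below, a minimal sketch of the zone macros documented in the profiler.h header above; it assumes `PROFILER_ENABLED` is defined to 1 for this translation unit and that the CPU's invariant timer is available.

```cpp
#define PROFILER_ENABLED 1
#include "hwy/profiler.h"

void FuncToMeasure() {
  PROFILER_FUNC;  // zone named after __func__
  // ... work to be measured ...
}

int main() {
  for (int i = 0; i < 1000; ++i) {
    PROFILER_ZONE("outer loop");  // RAII: the zone ends with this scope
    FuncToMeasure();
  }
  // Call once, after all threads have left their zones.
  PROFILER_PRINT_RESULTS();
  return 0;
}
```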
+ static constexpr size_t kOffsetBits = 25; + static constexpr uint64_t kOffsetBias = 1ULL << (kOffsetBits - 1); + + // We need full-resolution timestamps; at an effective rate of 4 GHz, + // this permits 1 minute zone durations (for longer durations, split into + // multiple zones). Wraparound is handled by masking. + static constexpr size_t kTimestampBits = 64 - kOffsetBits; + static constexpr uint64_t kTimestampMask = (1ULL << kTimestampBits) - 1; + + static Packet Make(const size_t biased_offset, const uint64_t timestamp) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << kOffsetBits)); + + Packet packet; + packet.bits_ = + (biased_offset << kTimestampBits) + (timestamp & kTimestampMask); + + HWY_DASSERT(packet.BiasedOffset() == biased_offset); + HWY_DASSERT(packet.Timestamp() == (timestamp & kTimestampMask)); + return packet; + } + + uint64_t Timestamp() const { return bits_ & kTimestampMask; } + + size_t BiasedOffset() const { + const size_t biased_offset = (bits_ >> kTimestampBits); + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << kOffsetBits)); + return biased_offset; + } + + private: + uint64_t bits_; +}; +static_assert(sizeof(Packet) == 8, "Wrong Packet size"); + +// All translation units must use the same string origin. A static member +// function ensures this without requiring a separate .cc file. +struct StringOrigin { + // Returns the address of a string literal. Assuming zone names are also + // literals and stored nearby, we can represent them as offsets from this, + // which is faster to compute than hashes or even a static index. + static const char* Get() { + // Chosen such that no zone name is a prefix nor suffix of this string + // to ensure they aren't merged. Note zone exit packets use + // `biased_offset == kOffsetBias`. + static const char* string_origin = "__#__"; + return string_origin - Packet::kOffsetBias; + } +}; + +// Representation of an active zone, stored in a stack. Used to deduct +// child duration from the parent's self time. POD. +struct Node { + Packet packet; + uint64_t child_total; +}; +static_assert(sizeof(Node) == 16, "Wrong Node size"); + +// Holds statistics for all zones with the same name. POD. +struct Accumulator { + static constexpr size_t kNumCallBits = 64 - Packet::kOffsetBits; + + uint64_t BiasedOffset() const { + const size_t biased_offset = u128.lo >> kNumCallBits; + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + return biased_offset; + } + uint64_t NumCalls() const { return u128.lo & ((1ULL << kNumCallBits) - 1); } + uint64_t Duration() const { return u128.hi; } + + void Set(uint64_t biased_offset, uint64_t num_calls, uint64_t duration) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + HWY_DASSERT(num_calls < (1ULL << kNumCallBits)); + + u128.hi = duration; + u128.lo = (biased_offset << kNumCallBits) + num_calls; + + HWY_DASSERT(BiasedOffset() == biased_offset); + HWY_DASSERT(NumCalls() == num_calls); + HWY_DASSERT(Duration() == duration); + } + + void Add(uint64_t num_calls, uint64_t duration) { + const uint64_t biased_offset = BiasedOffset(); + (void)biased_offset; + + u128.lo += num_calls; + u128.hi += duration; + + HWY_DASSERT(biased_offset == BiasedOffset()); + } + + // For fast sorting by duration, which must therefore be the hi element. + // lo holds BiasedOffset and NumCalls. 
+ uint128_t u128; +}; +static_assert(sizeof(Accumulator) == 16, "Wrong Accumulator size"); + +template +inline T ClampedSubtract(const T minuend, const T subtrahend) { + if (subtrahend > minuend) { + return 0; + } + return minuend - subtrahend; +} + +// Per-thread call graph (stack) and Accumulator for each zone. +class Results { + public: + Results() { + ZeroBytes(nodes_, sizeof(nodes_)); + ZeroBytes(zones_, sizeof(zones_)); + } + + // Used for computing overhead when this thread encounters its first Zone. + // This has no observable effect apart from increasing "analyze_elapsed_". + uint64_t ZoneDuration(const Packet* packets) { + HWY_DASSERT(depth_ == 0); + HWY_DASSERT(num_zones_ == 0); + AnalyzePackets(packets, 2); + const uint64_t duration = zones_[0].Duration(); + zones_[0].Set(1, 0, 0); // avoids triggering biased_offset = 0 checks + HWY_DASSERT(depth_ == 0); + num_zones_ = 0; + return duration; + } + + void SetSelfOverhead(const uint64_t self_overhead) { + self_overhead_ = self_overhead; + } + + void SetChildOverhead(const uint64_t child_overhead) { + child_overhead_ = child_overhead; + } + + // Draw all required information from the packets, which can be discarded + // afterwards. Called whenever this thread's storage is full. + void AnalyzePackets(const Packet* packets, const size_t num_packets) { + namespace hn = HWY_NAMESPACE; + const uint64_t t0 = hn::timer::Start(); + + for (size_t i = 0; i < num_packets; ++i) { + const Packet p = packets[i]; + // Entering a zone + if (p.BiasedOffset() != Packet::kOffsetBias) { + HWY_DASSERT(depth_ < kMaxDepth); + nodes_[depth_].packet = p; + HWY_DASSERT(p.BiasedOffset() != 0); + nodes_[depth_].child_total = 0; + ++depth_; + continue; + } + + HWY_DASSERT(depth_ != 0); + const Node& node = nodes_[depth_ - 1]; + // Masking correctly handles unsigned wraparound. + const uint64_t duration = + (p.Timestamp() - node.packet.Timestamp()) & Packet::kTimestampMask; + const uint64_t self_duration = ClampedSubtract( + duration, self_overhead_ + child_overhead_ + node.child_total); + + UpdateOrAdd(node.packet.BiasedOffset(), 1, self_duration); + --depth_; + + // Deduct this nested node's time from its parent's self_duration. + if (depth_ != 0) { + nodes_[depth_ - 1].child_total += duration + child_overhead_; + } + } + + const uint64_t t1 = hn::timer::Stop(); + analyze_elapsed_ += t1 - t0; + } + + // Incorporates results from another thread. Call after all threads have + // exited any zones. + void Assimilate(Results& other) { + namespace hn = HWY_NAMESPACE; + const uint64_t t0 = hn::timer::Start(); + HWY_DASSERT(depth_ == 0); + HWY_DASSERT(other.depth_ == 0); + + for (size_t i = 0; i < other.num_zones_; ++i) { + const Accumulator& zone = other.zones_[i]; + UpdateOrAdd(zone.BiasedOffset(), zone.NumCalls(), zone.Duration()); + } + other.num_zones_ = 0; + const uint64_t t1 = hn::timer::Stop(); + analyze_elapsed_ += t1 - t0 + other.analyze_elapsed_; + } + + // Single-threaded. + void Print() { + namespace hn = HWY_NAMESPACE; + const uint64_t t0 = hn::timer::Start(); + MergeDuplicates(); + + // Sort by decreasing total (self) cost. 
+ // VQSort(&zones_[0].u128, num_zones_, SortDescending()); + std::sort(zones_, zones_ + num_zones_, + [](const Accumulator& z1, const Accumulator& z2) { + return z1.Duration() > z2.Duration(); + }); + + const double inv_freq = 1.0 / platform::InvariantTicksPerSecond(); + + const char* string_origin = StringOrigin::Get(); + for (size_t i = 0; i < num_zones_; ++i) { + const Accumulator& z = zones_[i]; + const size_t num_calls = z.NumCalls(); + const double duration = static_cast(z.Duration()); + printf("%-40s: %10zu x %15.0f = %9.6f\n", + string_origin + z.BiasedOffset(), num_calls, duration / num_calls, + duration * inv_freq); + } + num_zones_ = 0; + + const uint64_t t1 = hn::timer::Stop(); + analyze_elapsed_ += t1 - t0; + printf("Total analysis [s]: %f\n", + static_cast(analyze_elapsed_) * inv_freq); + } + + private: + // Updates an existing Accumulator (uniquely identified by biased_offset) or + // adds one if this is the first time this thread analyzed that zone. + // Uses a self-organizing list data structure, which avoids dynamic memory + // allocations and is far faster than unordered_map. + void UpdateOrAdd(const size_t biased_offset, const uint64_t num_calls, + const uint64_t duration) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + + // Special case for first zone: (maybe) update, without swapping. + if (num_zones_ != 0 && zones_[0].BiasedOffset() == biased_offset) { + zones_[0].Add(num_calls, duration); + return; + } + + // Look for a zone with the same offset. + for (size_t i = 1; i < num_zones_; ++i) { + if (zones_[i].BiasedOffset() == biased_offset) { + zones_[i].Add(num_calls, duration); + // Swap with predecessor (more conservative than move to front, + // but at least as successful). + const Accumulator prev = zones_[i - 1]; + zones_[i - 1] = zones_[i]; + zones_[i] = prev; + return; + } + } + + // Not found; create a new Accumulator. + HWY_DASSERT(num_zones_ < kMaxZones); + zones_[num_zones_].Set(biased_offset, num_calls, duration); + ++num_zones_; + } + + // Each instantiation of a function template seems to get its own copy of + // __func__ and GCC doesn't merge them. An N^2 search for duplicates is + // acceptable because we only expect a few dozen zones. + void MergeDuplicates() { + const char* string_origin = StringOrigin::Get(); + for (size_t i = 0; i < num_zones_; ++i) { + const size_t biased_offset = zones_[i].BiasedOffset(); + const char* name = string_origin + biased_offset; + // Separate num_calls from biased_offset so we can add them together. + uint64_t num_calls = zones_[i].NumCalls(); + + // Add any subsequent duplicates to num_calls and total_duration. + for (size_t j = i + 1; j < num_zones_;) { + if (!strcmp(name, string_origin + zones_[j].BiasedOffset())) { + num_calls += zones_[j].NumCalls(); + zones_[i].Add(0, zones_[j].Duration()); + // j was the last zone, so we are done. + if (j == num_zones_ - 1) break; + // Replace current zone with the last one, and check it next. + zones_[j] = zones_[--num_zones_]; + } else { // Name differed, try next Accumulator. + ++j; + } + } + + // Re-pack regardless of whether any duplicates were found. + zones_[i].Set(biased_offset, num_calls, zones_[i].Duration()); + } + } + + uint64_t analyze_elapsed_ = 0; + uint64_t self_overhead_ = 0; + uint64_t child_overhead_ = 0; + + size_t depth_ = 0; // Number of active zones. + size_t num_zones_ = 0; // Number of retired zones. 
+ + alignas(HWY_ALIGNMENT) Node nodes_[kMaxDepth]; // Stack + alignas(HWY_ALIGNMENT) Accumulator zones_[kMaxZones]; // Self-organizing list +}; + +// Per-thread packet storage, dynamically allocated. +class ThreadSpecific { + static constexpr size_t kBufferCapacity = HWY_ALIGNMENT / sizeof(Packet); + + public: + // "name" is used to sanity-check offsets fit in kOffsetBits. + explicit ThreadSpecific(const char* name) + : max_packets_((PROFILER_THREAD_STORAGE << 20) / sizeof(Packet)), + packets_(AllocateAligned(max_packets_)), + num_packets_(0), + string_origin_(StringOrigin::Get()) { + // Even in optimized builds, verify that this zone's name offset fits + // within the allotted space. If not, UpdateOrAdd is likely to overrun + // zones_[]. Checking here on the cold path (only reached once per thread) + // is cheap, but it only covers one zone. + const size_t biased_offset = name - string_origin_; + HWY_ASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + } + + // Depends on Zone => defined below. + void ComputeOverhead(); + + void WriteEntry(const char* name, const uint64_t timestamp) { + HWY_DASSERT(name >= string_origin_); + const size_t biased_offset = static_cast(name - string_origin_); + Write(Packet::Make(biased_offset, timestamp)); + } + + void WriteExit(const uint64_t timestamp) { + const size_t biased_offset = Packet::kOffsetBias; + Write(Packet::Make(biased_offset, timestamp)); + } + + void AnalyzeRemainingPackets() { + // Ensures prior weakly-ordered streaming stores are globally visible. + FlushStream(); + + // Storage full => empty it. + if (num_packets_ + buffer_size_ > max_packets_) { + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + CopyBytes(buffer_, packets_.get() + num_packets_, + buffer_size_ * sizeof(Packet)); + num_packets_ += buffer_size_; + + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + + Results& GetResults() { return results_; } + + private: + // Overwrites "to" while attempting to bypass the cache (read-for-ownership). + // Both pointers must be aligned. + static void StreamCacheLine(const uint64_t* HWY_RESTRICT from, + uint64_t* HWY_RESTRICT to) { +#if HWY_COMPILER_CLANG + for (size_t i = 0; i < HWY_ALIGNMENT / sizeof(uint64_t); ++i) { + __builtin_nontemporal_store(from[i], to + i); + } +#else + hwy::CopyBytes(from, to, HWY_ALIGNMENT); +#endif + } + + // Write packet to buffer/storage, emptying them as needed. + void Write(const Packet packet) { + // Buffer full => copy to storage. + if (buffer_size_ == kBufferCapacity) { + // Storage full => empty it. + if (num_packets_ + kBufferCapacity > max_packets_) { + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + // This buffering halves observer overhead and decreases the overall + // runtime by about 3%. Casting is safe because the first member is u64. + StreamCacheLine( + reinterpret_cast(buffer_), + reinterpret_cast(packets_.get() + num_packets_)); + num_packets_ += kBufferCapacity; + buffer_size_ = 0; + } + buffer_[buffer_size_] = packet; + ++buffer_size_; + } + + // Write-combining buffer to avoid cache pollution. Must be the first + // non-static member to ensure cache-line alignment. + Packet buffer_[kBufferCapacity]; + size_t buffer_size_ = 0; + + const size_t max_packets_; + // Contiguous storage for zone enter/exit packets. + AlignedFreeUniquePtr packets_; + size_t num_packets_; + // Cached here because we already read this cache line on zone entry/exit. 
+ const char* string_origin_; + Results results_; +}; + +class ThreadList { + public: + // Called from any thread. + ThreadSpecific* Add(const char* name) { + const size_t index = num_threads_.fetch_add(1, std::memory_order_relaxed); + HWY_DASSERT(index < kMaxThreads); + + ThreadSpecific* ts = MakeUniqueAligned(name).release(); + threads_[index].store(ts, std::memory_order_release); + return ts; + } + + // Single-threaded. + void PrintResults() { + const auto acq = std::memory_order_acquire; + const size_t num_threads = num_threads_.load(acq); + + ThreadSpecific* main = threads_[0].load(acq); + main->AnalyzeRemainingPackets(); + + for (size_t i = 1; i < num_threads; ++i) { + ThreadSpecific* ts = threads_[i].load(acq); + ts->AnalyzeRemainingPackets(); + main->GetResults().Assimilate(ts->GetResults()); + } + + if (num_threads != 0) { + main->GetResults().Print(); + } + } + + private: + // Owning pointers. + alignas(64) std::atomic threads_[kMaxThreads]; + std::atomic num_threads_{0}; +}; + +// RAII zone enter/exit recorder constructed by the ZONE macro; also +// responsible for initializing ThreadSpecific. +class Zone { + public: + // "name" must be a string literal (see StringOrigin::Get). + HWY_NOINLINE explicit Zone(const char* name) { + HWY_FENCE; + ThreadSpecific* HWY_RESTRICT thread_specific = StaticThreadSpecific(); + if (HWY_UNLIKELY(thread_specific == nullptr)) { + // Ensure the CPU supports our timer. + char cpu[100]; + if (!platform::HaveTimerStop(cpu)) { + HWY_ABORT("CPU %s is too old for PROFILER_ENABLED=1, exiting", cpu); + } + + thread_specific = StaticThreadSpecific() = Threads().Add(name); + // Must happen after setting StaticThreadSpecific, because ComputeOverhead + // also calls Zone(). + thread_specific->ComputeOverhead(); + } + + // (Capture timestamp ASAP, not inside WriteEntry.) + HWY_FENCE; + const uint64_t timestamp = HWY_NAMESPACE::timer::Start(); + thread_specific->WriteEntry(name, timestamp); + } + + HWY_NOINLINE ~Zone() { + HWY_FENCE; + const uint64_t timestamp = HWY_NAMESPACE::timer::Stop(); + StaticThreadSpecific()->WriteExit(timestamp); + HWY_FENCE; + } + + // Call exactly once after all threads have exited all zones. + static void PrintResults() { Threads().PrintResults(); } + + private: + // Returns reference to the thread's ThreadSpecific pointer (initially null). + // Function-local static avoids needing a separate definition. + static ThreadSpecific*& StaticThreadSpecific() { + static thread_local ThreadSpecific* thread_specific; + return thread_specific; + } + + // Returns the singleton ThreadList. Non time-critical. + static ThreadList& Threads() { + static ThreadList threads_; + return threads_; + } +}; + +// Creates a zone starting from here until the end of the current scope. +// Timestamps will be recorded when entering and exiting the zone. +// "name" must be a string literal, which is ensured by merging with "". +#define PROFILER_ZONE(name) \ + HWY_FENCE; \ + const hwy::Zone zone("" name); \ + HWY_FENCE + +// Creates a zone for an entire function (when placed at its beginning). +// Shorter/more convenient than ZONE. +#define PROFILER_FUNC \ + HWY_FENCE; \ + const hwy::Zone zone(__func__); \ + HWY_FENCE + +#define PROFILER_PRINT_RESULTS hwy::Zone::PrintResults + +inline void ThreadSpecific::ComputeOverhead() { + namespace hn = HWY_NAMESPACE; + // Delay after capturing timestamps before/after the actual zone runs. Even + // with frequency throttling disabled, this has a multimodal distribution, + // including 32, 34, 48, 52, 59, 62. 
+ uint64_t self_overhead; + { + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 1024; + uint32_t durations[kNumDurations]; + + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + { + PROFILER_ZONE("Dummy Zone (never shown)"); + } + const uint64_t duration = results_.ZoneDuration(buffer_); + buffer_size_ = 0; + durations[idx_duration] = static_cast(duration); + HWY_DASSERT(num_packets_ == 0); + } + robust_statistics::CountingSort(durations, kNumDurations); + samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations); + } + // Median. + robust_statistics::CountingSort(samples, kNumSamples); + self_overhead = samples[kNumSamples / 2]; + if (PROFILER_PRINT_OVERHEAD) { + printf("Overhead: %.0f\n", static_cast(self_overhead)); + } + results_.SetSelfOverhead(self_overhead); + } + + // Delay before capturing start timestamp / after end timestamp. + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 16; + uint32_t durations[kNumDurations]; + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + const size_t kReps = 10000; + // Analysis time should not be included => must fit within buffer. + HWY_DASSERT(kReps * 2 < max_packets_); + std::atomic_thread_fence(std::memory_order_seq_cst); + const uint64_t t0 = hn::timer::Start(); + for (size_t i = 0; i < kReps; ++i) { + PROFILER_ZONE("Dummy"); + } + FlushStream(); + const uint64_t t1 = hn::timer::Stop(); + HWY_DASSERT(num_packets_ + buffer_size_ == kReps * 2); + buffer_size_ = 0; + num_packets_ = 0; + const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps; + durations[idx_duration] = + static_cast(ClampedSubtract(avg_duration, self_overhead)); + } + robust_statistics::CountingSort(durations, kNumDurations); + samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations); + } + robust_statistics::CountingSort(samples, kNumSamples); + const uint64_t child_overhead = samples[9 * kNumSamples / 10]; + if (PROFILER_PRINT_OVERHEAD) { + printf("Child overhead: %.0f\n", static_cast(child_overhead)); + } + results_.SetChildOverhead(child_overhead); +} + +#pragma pack(pop) + +} // namespace hwy + +#endif // PROFILER_ENABLED || HWY_IDE + +#if !PROFILER_ENABLED && !HWY_IDE +#define PROFILER_ZONE(name) +#define PROFILER_FUNC +#define PROFILER_PRINT_RESULTS() +#endif + +#endif // HIGHWAY_HWY_PROFILER_H_ diff --git a/r/src/vendor/highway/hwy/robust_statistics.h b/r/src/vendor/highway/hwy/robust_statistics.h index 1cf3e5d2..d80a2146 100644 --- a/r/src/vendor/highway/hwy/robust_statistics.h +++ b/r/src/vendor/highway/hwy/robust_statistics.h @@ -135,8 +135,8 @@ T MedianAbsoluteDeviation(const T* values, const size_t num_values, std::vector abs_deviations; abs_deviations.reserve(num_values); for (size_t i = 0; i < num_values; ++i) { - const int64_t abs = std::abs(static_cast(values[i]) - - static_cast(median)); + const int64_t abs = ScalarAbs(static_cast(values[i]) - + static_cast(median)); abs_deviations.push_back(static_cast(abs)); } return Median(abs_deviations.data(), num_values); diff --git a/r/src/vendor/highway/hwy/stats.cc b/r/src/vendor/highway/hwy/stats.cc new file mode 100644 index 00000000..4c53124b --- /dev/null +++ b/r/src/vendor/highway/hwy/stats.cc @@ -0,0 +1,120 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// 
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/stats.h" + +#include + +#include // std::min +#include + +#include "hwy/base.h" // HWY_ASSERT + +namespace hwy { + +void Stats::Assimilate(const Stats& other) { + const int64_t total_n = n_ + other.n_; + if (total_n == 0) return; // Nothing to do; prevents div by zero. + + min_ = std::min(min_, other.min_); + max_ = std::max(max_, other.max_); + + sum_log_ += other.sum_log_; + + const double product_n = n_ * other.n_; + const double n2 = n_ * n_; + const double other_n2 = other.n_ * other.n_; + const int64_t total_n2 = total_n * total_n; + const double total_n3 = static_cast(total_n2) * total_n; + // Precompute reciprocal for speed - used at least twice. + const double inv_total_n = 1.0 / total_n; + const double inv_total_n2 = 1.0 / total_n2; + + const double delta = other.m1_ - m1_; + const double delta2 = delta * delta; + const double delta3 = delta * delta2; + const double delta4 = delta2 * delta2; + + m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n; + + const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n; + + const double new_m3 = + m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 + + 3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n; + + m4_ += other.m4_ + + delta4 * product_n * (n2 - product_n + other_n2) / total_n3 + + 6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 + + 4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n; + + m2_ = new_m2; + m3_ = new_m3; + n_ = total_n; +} + +std::string Stats::ToString(int exclude) const { + if (Count() == 0) return std::string("(none)"); + + char buf[300]; + int pos = 0; + int ret; // snprintf - bytes written or negative for error. 
+ + if ((exclude & kNoCount) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ", + static_cast(Count())); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMeanSD) == 0) { + const float sd = StandardDeviation(); + if (sd > 100) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.2e SD=%7.1e ", + Mean(), sd); + } else { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.6e SD=%7.5e ", + Mean(), sd); + } + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMinMax) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5e Max=%8.5e ", Min(), + Max()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoSkewKurt) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ", + Skewness(), Kurtosis()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoGeomean) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ", + GeometricMean()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + HWY_ASSERT(pos < static_cast(sizeof(buf))); + return buf; +} + +} // namespace hwy diff --git a/r/src/vendor/highway/hwy/stats.h b/r/src/vendor/highway/hwy/stats.h new file mode 100644 index 00000000..207ad2bf --- /dev/null +++ b/r/src/vendor/highway/hwy/stats.h @@ -0,0 +1,194 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_STATS_H_ +#define HIGHWAY_HWY_STATS_H_ + +#include +#include + +#include +#include + +#include "hwy/base.h" // HWY_ASSERT + +namespace hwy { + +// Thread-compatible. +template +class Bins { + public: + Bins() { Reset(); } + + template + void Notify(T bin) { + HWY_ASSERT(T{0} <= bin && bin < static_cast(N)); + counts_[static_cast(bin)]++; + } + + void Assimilate(const Bins& other) { + for (size_t i = 0; i < N; ++i) { + counts_[i] += other.counts_[i]; + } + } + + void Print(const char* caption) const { + fprintf(stderr, "\n%s [%zu]\n", caption, N); + size_t last_nonzero = 0; + for (size_t i = N - 1; i < N; --i) { + if (counts_[i] != 0) { + last_nonzero = i; + break; + } + } + for (size_t i = 0; i <= last_nonzero; ++i) { + fprintf(stderr, " %zu\n", counts_[i]); + } + } + + void Reset() { + for (size_t i = 0; i < N; ++i) { + counts_[i] = 0; + } + } + + private: + size_t counts_[N]; +}; + +// Descriptive statistics of a variable (4 moments). Thread-compatible. +class Stats { + public: + Stats() { Reset(); } + + void Notify(const float x) { + ++n_; + + min_ = HWY_MIN(min_, x); + max_ = HWY_MAX(max_, x); + + // Logarithmic transform avoids/delays underflow and overflow. + sum_log_ += std::log(static_cast(x)); + + // Online moments. 
Reference: https://goo.gl/9ha694 + const double d = x - m1_; + const double d_div_n = d / n_; + const double d2n1_div_n = d * (n_ - 1) * d_div_n; + const int64_t n_poly = n_ * n_ - 3 * n_ + 3; + m1_ += d_div_n; + m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_); + m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_); + m2_ += d2n1_div_n; + } + + void Assimilate(const Stats& other); + + int64_t Count() const { return n_; } + + float Min() const { return min_; } + float Max() const { return max_; } + + double GeometricMean() const { + return n_ == 0 ? 0.0 : std::exp(sum_log_ / n_); + } + + double Mean() const { return m1_; } + // Same as Mu2. Assumes n_ is large. + double SampleVariance() const { + return n_ == 0 ? 0.0 : m2_ / static_cast(n_); + } + // Unbiased estimator for population variance even for smaller n_. + double Variance() const { + if (n_ == 0) return 0.0; + if (n_ == 1) return m2_; + return m2_ / static_cast(n_ - 1); + } + double StandardDeviation() const { return std::sqrt(Variance()); } + // Near zero for normal distributions; if positive on a unimodal distribution, + // the right tail is fatter. Assumes n_ is large. + double SampleSkewness() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m3_ * std::sqrt(static_cast(n_)) / std::pow(m2_, 1.5); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Skewness() const { + if (n_ == 0) return 0.0; + const double biased = SampleSkewness(); + const double r = (n_ - 1.0) / n_; + return biased * std::pow(r, 1.5); + } + // Near zero for normal distributions; smaller values indicate fewer/smaller + // outliers and larger indicates more/larger outliers. Assumes n_ is large. + double SampleKurtosis() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m4_ * n_ / (m2_ * m2_); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Kurtosis() const { + if (n_ == 0) return 0.0; + const double biased = SampleKurtosis(); + const double r = (n_ - 1.0) / n_; + return biased * r * r; + } + + // Central moments, useful for "method of moments"-based parameter estimation + // of a mixture of two Gaussians. Assumes Count() != 0. 
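The accumulator above supports streaming `Notify()` updates and cross-thread merging via `Assimilate()`; a usage sketch with arbitrary sample values:

```cpp
#include <cstdio>

#include "hwy/stats.h"

int main() {
  hwy::Stats stats;
  const float samples[] = {1.0f, 2.0f, 4.0f, 8.0f};
  for (float x : samples) stats.Notify(x);

  hwy::Stats other;         // e.g. filled by another thread
  other.Notify(16.0f);
  stats.Assimilate(other);  // merge the two sets of moments

  std::printf("%s\n", stats.ToString().c_str());
  std::printf("mean=%.3f sd=%.3f\n", stats.Mean(), stats.StandardDeviation());
  return 0;
}
```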
+ double Mu1() const { return m1_; } + double Mu2() const { return m2_ / static_cast(n_); } + double Mu3() const { return m3_ / static_cast(n_); } + double Mu4() const { return m4_ / static_cast(n_); } + + // Which statistics to EXCLUDE in ToString + enum { + kNoCount = 1, + kNoMeanSD = 2, + kNoMinMax = 4, + kNoSkewKurt = 8, + kNoGeomean = 16 + }; + std::string ToString(int exclude = 0) const; + + void Reset() { + n_ = 0; + + min_ = hwy::HighestValue(); + max_ = hwy::LowestValue(); + + sum_log_ = 0.0; + + m1_ = 0.0; + m2_ = 0.0; + m3_ = 0.0; + m4_ = 0.0; + } + + private: + int64_t n_; // signed for faster conversion + safe subtraction + + float min_; + float max_; + + double sum_log_; // for geomean + + // Moments + double m1_; + double m2_; + double m3_; + double m4_; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_STATS_H_ diff --git a/r/src/vendor/highway/hwy/targets.cc b/r/src/vendor/highway/hwy/targets.cc index e68f754d..b6c2419b 100644 --- a/r/src/vendor/highway/hwy/targets.cc +++ b/r/src/vendor/highway/hwy/targets.cc @@ -15,17 +15,14 @@ #include "hwy/targets.h" -#include +#include #include -#include // abort / exit +#include "hwy/base.h" +#include "hwy/detect_targets.h" #include "hwy/highway.h" #include "hwy/per_target.h" // VectorBytes -#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN -#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace -#endif - #if HWY_ARCH_X86 #include #if HWY_COMPILER_MSVC @@ -34,18 +31,24 @@ #include #endif // HWY_COMPILER_MSVC -#elif (HWY_ARCH_ARM || HWY_ARCH_PPC) && HWY_OS_LINUX +#elif (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X || HWY_ARCH_RISCV) && \ + HWY_OS_LINUX // sys/auxv.h does not always include asm/hwcap.h, or define HWCAP*, hence we // still include this directly. See #1199. #ifndef TOOLCHAIN_MISS_ASM_HWCAP_H #include #endif -#ifndef TOOLCHAIN_MISS_SYS_AUXV_H +#if HWY_HAVE_AUXV #include #endif #endif // HWY_ARCH_* +#if HWY_OS_APPLE +#include +#include +#endif // HWY_OS_APPLE + namespace hwy { namespace { @@ -56,6 +59,71 @@ int64_t supported_targets_for_test_ = 0; // Mask of targets disabled at runtime with DisableTargets. 
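The mask described in the comment above is driven by the public `DisableTargets()`; a usage sketch, assuming the `SupportedTargets`/`DisableTargets`/`TargetName` declarations already present in targets.h:

```cpp
#include <cstdio>

#include "hwy/detect_targets.h"  // HWY_AVX3* bit values
#include "hwy/targets.h"

int main() {
  std::printf("supported: 0x%llx\n",
              static_cast<unsigned long long>(hwy::SupportedTargets()));

  // Pretend the AVX-512 targets are unavailable; subsequent dynamic dispatch
  // (and e.g. hwy::VectorBytes()) will pick the best remaining target.
  hwy::DisableTargets(HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_SPR);

  std::printf("after DisableTargets: 0x%llx\n",
              static_cast<unsigned long long>(hwy::SupportedTargets()));
  return 0;
}
```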
int64_t supported_mask_ = LimitsMax(); +#if HWY_OS_APPLE +static HWY_INLINE HWY_MAYBE_UNUSED bool HasCpuFeature( + const char* feature_name) { + int result = 0; + size_t len = sizeof(int); + return (sysctlbyname(feature_name, &result, &len, nullptr, 0) == 0 && + result != 0); +} + +static HWY_INLINE HWY_MAYBE_UNUSED bool ParseU32(const char*& ptr, + uint32_t& parsed_val) { + uint64_t parsed_u64 = 0; + + const char* start_ptr = ptr; + for (char ch; (ch = (*ptr)) != '\0'; ++ptr) { + unsigned digit = static_cast(static_cast(ch)) - + static_cast(static_cast('0')); + if (digit > 9u) { + break; + } + + parsed_u64 = (parsed_u64 * 10u) + digit; + if (parsed_u64 > 0xFFFFFFFFu) { + return false; + } + } + + parsed_val = static_cast(parsed_u64); + return (ptr != start_ptr); +} + +static HWY_INLINE HWY_MAYBE_UNUSED bool IsMacOs12_2OrLater() { + utsname uname_buf; + ZeroBytes(&uname_buf, sizeof(utsname)); + + if ((uname(&uname_buf)) != 0) { + return false; + } + + const char* ptr = uname_buf.release; + if (!ptr) { + return false; + } + + uint32_t major; + uint32_t minor; + if (!ParseU32(ptr, major)) { + return false; + } + + if (*ptr != '.') { + return false; + } + + ++ptr; + if (!ParseU32(ptr, minor)) { + return false; + } + + // We are running on macOS 12.2 or later if the Darwin kernel version is 21.3 + // or later + return (major > 21 || (major == 21 && minor >= 3)); +} +#endif // HWY_OS_APPLE + #if HWY_ARCH_X86 && HWY_HAVE_RUNTIME_DISPATCH namespace x86 { @@ -136,6 +204,7 @@ enum class FeatureIndex : uint32_t { kAVX512DQ, kAVX512BW, kAVX512FP16, + kAVX512BF16, kVNNI, kVPCLMULQDQ, @@ -203,6 +272,9 @@ uint64_t FlagsFromCPUID() { flags |= IsBitSet(abcd[2], 14) ? Bit(FeatureIndex::kPOPCNTDQ) : 0; flags |= IsBitSet(abcd[3], 23) ? Bit(FeatureIndex::kAVX512FP16) : 0; + + Cpuid(7, 1, abcd); + flags |= IsBitSet(abcd[0], 5) ? Bit(FeatureIndex::kAVX512BF16) : 0; } return flags; @@ -252,14 +324,17 @@ constexpr uint64_t kGroupAVX3_DL = Bit(FeatureIndex::kVAES) | Bit(FeatureIndex::kPOPCNTDQ) | Bit(FeatureIndex::kBITALG) | Bit(FeatureIndex::kGFNI) | kGroupAVX3; +constexpr uint64_t kGroupAVX3_ZEN4 = + Bit(FeatureIndex::kAVX512BF16) | kGroupAVX3_DL; + constexpr uint64_t kGroupAVX3_SPR = - Bit(FeatureIndex::kAVX512FP16) | kGroupAVX3_DL; + Bit(FeatureIndex::kAVX512FP16) | kGroupAVX3_ZEN4; int64_t DetectTargets() { int64_t bits = 0; // return value of supported targets. -#if HWY_ARCH_X86_64 - bits |= HWY_SSE2; // always present in x64 -#endif + HWY_IF_CONSTEXPR(HWY_ARCH_X86_64) { + bits |= HWY_SSE2; // always present in x64 + } const uint64_t flags = FlagsFromCPUID(); // Set target bit(s) if all their group's flags are all set. @@ -281,48 +356,93 @@ int64_t DetectTargets() { if ((flags & kGroupSSSE3) == kGroupSSSE3) { bits |= HWY_SSSE3; } -#if HWY_ARCH_X86_32 - if ((flags & kGroupSSE2) == kGroupSSE2) { - bits |= HWY_SSE2; + HWY_IF_CONSTEXPR(HWY_ARCH_X86_32) { + if ((flags & kGroupSSE2) == kGroupSSE2) { + bits |= HWY_SSE2; + } } -#endif - // Clear bits if the OS does not support XSAVE - otherwise, registers - // are not preserved across context switches. + // Clear AVX2/AVX3 bits if the CPU or OS does not support XSAVE - otherwise, + // YMM/ZMM registers are not preserved across context switches. 
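The Apple-specific helpers above reduce to two OS queries: `sysctlbyname` for a boolean capability flag, and `uname`'s release string for the Darwin kernel version (21.3 or later corresponds to macOS 12.2 or later, the first release the patch trusts to preserve AVX-512 state). A standalone, macOS-only sketch of the same two checks; `sscanf` stands in for the patch's `ParseU32` helper:

```cpp
// macOS-only sketch mirroring HasCpuFeature() and IsMacOs12_2OrLater() above.
#include <sys/sysctl.h>
#include <sys/utsname.h>

#include <cstdio>

static bool HasFeature(const char* name) {
  int value = 0;
  size_t len = sizeof(value);
  return sysctlbyname(name, &value, &len, nullptr, 0) == 0 && value != 0;
}

int main() {
  utsname u{};
  if (uname(&u) != 0) return 1;
  int major = 0, minor = 0;
  // Darwin release strings look like "21.3.0".
  if (sscanf(u.release, "%d.%d", &major, &minor) != 2) return 1;
  const bool is_macos_12_2_or_later = major > 21 || (major == 21 && minor >= 3);
  printf("Darwin %d.%d (macOS >= 12.2: %d), hw.optional.avx512f: %d\n", major,
         minor, is_macos_12_2_or_later ? 1 : 0,
         HasFeature("hw.optional.avx512f") ? 1 : 0);
  return 0;
}
```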
+ + // The lower 128 bits of XMM0-XMM15 are guaranteed to be preserved across + // context switches on x86_64 + + // The following OS's are known to preserve the lower 128 bits of XMM + // registers across context switches on x86 CPU's that support SSE (even in + // 32-bit mode): + // - Windows 2000 or later + // - Linux 2.4.0 or later + // - Mac OS X 10.4 or later + // - FreeBSD 4.4 or later + // - NetBSD 1.6 or later + // - OpenBSD 3.5 or later + // - UnixWare 7 Release 7.1.1 or later + // - Solaris 9 4/04 or later + uint32_t abcd[4]; Cpuid(1, 0, abcd); + const bool has_xsave = IsBitSet(abcd[2], 26); const bool has_osxsave = IsBitSet(abcd[2], 27); - if (has_osxsave) { - const uint32_t xcr0 = ReadXCR0(); - const int64_t min_avx3 = HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_SPR; - const int64_t min_avx2 = HWY_AVX2 | min_avx3; - // XMM - if (!IsBitSet(xcr0, 1)) { -#if HWY_ARCH_X86_64 - // The HWY_SSE2, HWY_SSSE3, and HWY_SSE4 bits do not need to be - // cleared on x86_64, even if bit 1 of XCR0 is not set, as - // the lower 128 bits of XMM0-XMM15 are guaranteed to be - // preserved across context switches on x86_64 - - // Only clear the AVX2/AVX3 bits on x86_64 if bit 1 of XCR0 is not set - bits &= ~min_avx2; -#else - bits &= ~(HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | min_avx2); + constexpr int64_t min_avx2 = HWY_AVX2 | (HWY_AVX2 - 1); + + if (has_xsave && has_osxsave) { +#if HWY_OS_APPLE + // On macOS, check for AVX3 XSAVE support by checking that we are running on + // macOS 12.2 or later and HasCpuFeature("hw.optional.avx512f") returns true + + // There is a bug in macOS 12.1 or earlier that can cause ZMM16-ZMM31, the + // upper 256 bits of the ZMM registers, and K0-K7 (the AVX512 mask + // registers) to not be properly preserved across a context switch on + // macOS 12.1 or earlier. + + // This bug on macOS 12.1 or earlier on x86_64 CPU's with AVX3 support is + // described at + // https://community.intel.com/t5/Software-Tuning-Performance/MacOS-Darwin-kernel-bug-clobbers-AVX-512-opmask-register-state/m-p/1327259, + // https://github.com/golang/go/issues/49233, and + // https://github.com/simdutf/simdutf/pull/236. + + // In addition to the bug that is there on macOS 12.1 or earlier, bits 5, 6, + // and 7 can be set to 0 on x86_64 CPU's with AVX3 support on macOS until + // the first AVX512 instruction is executed as macOS only preserves + // ZMM16-ZMM31, the upper 256 bits of the ZMM registers, and K0-K7 across a + // context switch on threads that have executed an AVX512 instruction. + + // Checking for AVX3 XSAVE support on macOS using + // HasCpuFeature("hw.optional.avx512f") avoids false negative results + // on x86_64 CPU's that have AVX3 support. + const bool have_avx3_xsave_support = + IsMacOs12_2OrLater() && HasCpuFeature("hw.optional.avx512f"); #endif - } - // YMM - if (!IsBitSet(xcr0, 2)) { + + const uint32_t xcr0 = ReadXCR0(); + constexpr int64_t min_avx3 = HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_SPR; + // XMM/YMM + if (!IsBitSet(xcr0, 1) || !IsBitSet(xcr0, 2)) { + // Clear the AVX2/AVX3 bits if XMM/YMM XSAVE is not enabled bits &= ~min_avx2; } + +#if !HWY_OS_APPLE + // On OS's other than macOS, check for AVX3 XSAVE support by checking that + // bits 5, 6, and 7 of XCR0 are set. 
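The XSAVE logic above gates target bits on XCR0: bits 1 and 2 (SSE/AVX state) must be set for the AVX2/AVX3 groups to survive context switches, and bits 5-7 (opmask and upper ZMM state) are additionally required for AVX-512, except on macOS, where the sysctl check above is used instead. A hedged sketch of reading XCR0 with compiler intrinsics; it assumes a GCC/Clang build with `-mxsave` (or MSVC), and the caller must already have verified the XSAVE and OSXSAVE CPUID bits, as `DetectTargets()` does:

```cpp
#include <immintrin.h>  // _xgetbv

#include <cstdio>

int main() {
  // Precondition: CPUID.1:ECX bit 26 (XSAVE) and bit 27 (OSXSAVE) are set;
  // otherwise executing XGETBV faults.
  const unsigned long long xcr0 = _xgetbv(0);
  const bool sse_avx_saved = (xcr0 & 0x06) == 0x06;  // bits 1 and 2
  const bool avx512_saved = (xcr0 & 0xE0) == 0xE0;   // bits 5, 6 and 7
  printf("AVX state saved: %d, AVX-512 state saved: %d\n",
         sse_avx_saved ? 1 : 0, (sse_avx_saved && avx512_saved) ? 1 : 0);
  return 0;
}
```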
+ const bool have_avx3_xsave_support = + IsBitSet(xcr0, 5) && IsBitSet(xcr0, 6) && IsBitSet(xcr0, 7); +#endif + // opmask, ZMM lo/hi - if (!IsBitSet(xcr0, 5) || !IsBitSet(xcr0, 6) || !IsBitSet(xcr0, 7)) { + if (!have_avx3_xsave_support) { bits &= ~min_avx3; } - } // has_osxsave + } else { // !has_xsave || !has_osxsave + // Clear the AVX2/AVX3 bits if the CPU or OS does not support XSAVE + bits &= ~min_avx2; + } // This is mainly to work around the slow Zen4 CompressStore. It's unclear // whether subsequent AMD models will be affected; assume yes. - if ((bits & HWY_AVX3_DL) && IsAMD()) { + if ((bits & HWY_AVX3_DL) && (flags & kGroupAVX3_ZEN4) == kGroupAVX3_ZEN4 && + IsAMD()) { bits |= HWY_AVX3_ZEN4; } @@ -333,18 +453,43 @@ int64_t DetectTargets() { #elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH namespace arm { int64_t DetectTargets() { - int64_t bits = 0; // return value of supported targets. + int64_t bits = 0; // return value of supported targets. + using CapBits = unsigned long; // NOLINT +#if HWY_OS_APPLE + const CapBits hw = 0UL; +#else + // For Android, this has been supported since API 20 (2014). const CapBits hw = getauxval(AT_HWCAP); +#endif (void)hw; #if HWY_ARCH_ARM_A64 bits |= HWY_NEON_WITHOUT_AES; // aarch64 always has NEON and VFPv4.. +#if HWY_OS_APPLE + if (HasCpuFeature("hw.optional.arm.FEAT_AES")) { + bits |= HWY_NEON; + + if (HasCpuFeature("hw.optional.AdvSIMD_HPFPCvt") && + HasCpuFeature("hw.optional.arm.FEAT_DotProd") && + HasCpuFeature("hw.optional.arm.FEAT_BF16")) { + bits |= HWY_NEON_BF16; + } + } +#else // !HWY_OS_APPLE // .. but not necessarily AES, which is required for HWY_NEON. #if defined(HWCAP_AES) if (hw & HWCAP_AES) { bits |= HWY_NEON; + +#if defined(HWCAP_ASIMDHP) && defined(HWCAP_ASIMDDP) && defined(HWCAP2_BF16) + const CapBits hw2 = getauxval(AT_HWCAP2); + const int64_t kGroupF16Dot = HWCAP_ASIMDHP | HWCAP_ASIMDDP; + if ((hw & kGroupF16Dot) == kGroupF16Dot && (hw2 & HWCAP2_BF16)) { + bits |= HWY_NEON_BF16; + } +#endif // HWCAP_ASIMDHP && HWCAP_ASIMDDP && HWCAP2_BF16 } #endif // HWCAP_AES @@ -354,12 +499,17 @@ int64_t DetectTargets() { } #endif -#if defined(HWCAP2_SVE2) && defined(HWCAP2_SVEAES) +#ifndef HWCAP2_SVE2 +#define HWCAP2_SVE2 (1 << 1) +#endif +#ifndef HWCAP2_SVEAES +#define HWCAP2_SVEAES (1 << 2) +#endif const CapBits hw2 = getauxval(AT_HWCAP2); if ((hw2 & HWCAP2_SVE2) && (hw2 & HWCAP2_SVEAES)) { bits |= HWY_SVE2; } -#endif +#endif // HWY_OS_APPLE #else // !HWY_ARCH_ARM_A64 @@ -441,7 +591,91 @@ int64_t DetectTargets() { return bits; } } // namespace ppc -#endif // HWY_ARCH_X86 +#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH +namespace s390x { + +#ifndef HWCAP_S390_VX +#define HWCAP_S390_VX 2048 +#endif + +#ifndef HWCAP_S390_VXE +#define HWCAP_S390_VXE 8192 +#endif + +#ifndef HWCAP_S390_VXRS_EXT2 +#define HWCAP_S390_VXRS_EXT2 32768 +#endif + +using CapBits = unsigned long; // NOLINT + +constexpr CapBits kGroupZ14 = HWCAP_S390_VX | HWCAP_S390_VXE; +constexpr CapBits kGroupZ15 = + HWCAP_S390_VX | HWCAP_S390_VXE | HWCAP_S390_VXRS_EXT2; + +int64_t DetectTargets() { + int64_t bits = 0; + +#if defined(AT_HWCAP) + const CapBits hw = getauxval(AT_HWCAP); + + if ((hw & kGroupZ14) == kGroupZ14) { + bits |= HWY_Z14; + } + + if ((hw & kGroupZ15) == kGroupZ15) { + bits |= HWY_Z15; + } +#endif + + return bits; +} +} // namespace s390x +#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH +namespace rvv { + +#ifndef HWCAP_RVV +#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) +#endif + +using CapBits = unsigned long; // NOLINT + +int64_t DetectTargets() { + int64_t 
bits = 0; + + const CapBits hw = getauxval(AT_HWCAP); + + if ((hw & COMPAT_HWCAP_ISA_V) == COMPAT_HWCAP_ISA_V) { + size_t e8m1_vec_len; +#if HWY_ARCH_RISCV_64 + int64_t vtype_reg_val; +#else + int32_t vtype_reg_val; +#endif + + // Check that a vuint8m1_t vector is at least 16 bytes and that tail + // agnostic and mask agnostic mode are supported + asm volatile( + // Avoid compiler error on GCC or Clang if -march=rv64gcv1p0 or + // -march=rv32gcv1p0 option is not specified on the command line + ".option push\n\t" + ".option arch, +v\n\t" + "vsetvli %0, zero, e8, m1, ta, ma\n\t" + "csrr %1, vtype\n\t" + ".option pop" + : "=r"(e8m1_vec_len), "=r"(vtype_reg_val)); + + // The RVV target is supported if the VILL bit of VTYPE (the MSB bit of + // VTYPE) is not set and the length of a vuint8m1_t vector is at least 16 + // bytes + if (vtype_reg_val >= 0 && e8m1_vec_len >= 16) { + bits |= HWY_RVV; + } + } + + return bits; +} +} // namespace rvv +#endif // HWY_ARCH_* // Returns targets supported by the CPU, independently of DisableTargets. // Factored out of SupportedTargets to make its structure more obvious. Note @@ -457,9 +691,13 @@ int64_t DetectTargets() { bits |= arm::DetectTargets(); #elif HWY_ARCH_PPC && HWY_HAVE_RUNTIME_DISPATCH bits |= ppc::DetectTargets(); +#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH + bits |= s390x::DetectTargets(); +#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH + bits |= rvv::DetectTargets(); #else - // TODO(janwas): detect support for WASM/RVV. + // TODO(janwas): detect support for WASM. // This file is typically compiled without HWY_IS_TEST, but targets_test has // it set, and will expect all of its HWY_TARGETS (= all attainable) to be // supported. @@ -469,12 +707,11 @@ int64_t DetectTargets() { if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { const uint64_t bits_u = static_cast(bits); const uint64_t enabled = static_cast(HWY_ENABLED_BASELINE); - fprintf(stderr, - "WARNING: CPU supports 0x%08x%08x, software requires 0x%08x%08x\n", - static_cast(bits_u >> 32), - static_cast(bits_u & 0xFFFFFFFF), - static_cast(enabled >> 32), - static_cast(enabled & 0xFFFFFFFF)); + HWY_WARN("CPU supports 0x%08x%08x, software requires 0x%08x%08x\n", + static_cast(bits_u >> 32), + static_cast(bits_u & 0xFFFFFFFF), + static_cast(enabled >> 32), + static_cast(enabled & 0xFFFFFFFF)); } return bits; @@ -482,34 +719,6 @@ int64_t DetectTargets() { } // namespace -HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) - Abort(const char* file, int line, const char* format, ...) { - char buf[800]; - va_list args; - va_start(args, format); - vsnprintf(buf, sizeof(buf), format, args); - va_end(args); - - fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); - -// If compiled with any sanitizer, they can also print a stack trace. -#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN - __sanitizer_print_stack_trace(); -#endif // HWY_IS_* - fflush(stderr); - -// Now terminate the program: -#if HWY_ARCH_RVV - exit(1); // trap/abort just freeze Spike. -#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC - // Facilitates breaking into a debugger, but don't use this in non-debug - // builds because it looks like "illegal instruction", which is misleading. - __builtin_trap(); -#else - abort(); // Compile error without this due to HWY_NORETURN. 
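The non-x86 Linux paths above (Arm, POWER, s390x, RISC-V) all follow the same pattern: read capability bits from the ELF auxiliary vector and translate groups of them into Highway target bits. A minimal Linux-only sketch of that pattern for aarch64; the fallback value for `HWCAP_AES` is an assumption for old kernel headers, mirroring how the patch defines missing `HWCAP2_*` constants:

```cpp
// Linux-only sketch of auxiliary-vector feature detection (aarch64 flavour).
#include <sys/auxv.h>  // getauxval, AT_HWCAP

#include <cstdio>

#ifndef HWCAP_AES
#define HWCAP_AES (1 << 3)  // assumed value from aarch64 kernel headers
#endif

int main() {
  const unsigned long hw = getauxval(AT_HWCAP);  // NOLINT
  const bool has_aes = (hw & HWCAP_AES) != 0;
  printf("AT_HWCAP = 0x%lx, AES: %d\n", hw, has_aes ? 1 : 0);
  return 0;
}
```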
-#endif -} - HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets) { supported_mask_ = static_cast(~disabled_targets); // This will take effect on the next call to SupportedTargets, which is diff --git a/r/src/vendor/highway/hwy/targets.h b/r/src/vendor/highway/hwy/targets.h index 693e2e80..b3573dd1 100644 --- a/r/src/vendor/highway/hwy/targets.h +++ b/r/src/vendor/highway/hwy/targets.h @@ -29,7 +29,7 @@ #include "hwy/detect_targets.h" #include "hwy/highway_export.h" -#if !HWY_ARCH_RVV && !defined(HWY_NO_LIBCXX) +#if !defined(HWY_NO_LIBCXX) #include #endif @@ -112,6 +112,8 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { return "SVE2"; case HWY_SVE: return "SVE"; + case HWY_NEON_BF16: + return "NEON_BF16"; case HWY_NEON: return "NEON"; case HWY_NEON_WITHOUT_AES: @@ -127,6 +129,13 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { return "PPC10"; #endif +#if HWY_ARCH_S390X + case HWY_Z14: + return "Z14"; + case HWY_Z15: + return "Z15"; +#endif + #if HWY_ARCH_WASM case HWY_WASM: return "WASM"; @@ -134,7 +143,7 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { return "WASM_EMU256"; #endif -#if HWY_ARCH_RVV +#if HWY_ARCH_RISCV case HWY_RVV: return "RVV"; #endif @@ -213,24 +222,24 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { // See HWY_ARCH_X86 above for details. #define HWY_MAX_DYNAMIC_TARGETS 15 #define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM -#define HWY_CHOOSE_TARGET_LIST(func_name) \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ - HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ - HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ - HWY_CHOOSE_SVE(func_name), /* SVE */ \ - HWY_CHOOSE_NEON(func_name), /* NEON */ \ +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ + HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ + HWY_CHOOSE_SVE(func_name), /* SVE */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON_BF16(func_name), /* NEON + f16/dot/bf16 */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON(func_name), /* NEON */ \ HWY_CHOOSE_NEON_WITHOUT_AES(func_name) /* NEON without AES */ -#elif HWY_ARCH_RVV +#elif HWY_ARCH_RISCV // See HWY_ARCH_X86 above for details. #define HWY_MAX_DYNAMIC_TARGETS 9 #define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV @@ -245,20 +254,20 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { HWY_CHOOSE_RVV(func_name), /* RVV */ \ nullptr /* reserved */ -#elif HWY_ARCH_PPC +#elif HWY_ARCH_PPC || HWY_ARCH_S390X // See HWY_ARCH_X86 above for details. 
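targets.h, shown above, is the public face of all this detection: `SupportedTargets()` returns the detected bit mask, `TargetName()` maps a single bit to a printable name, and `DisableTargets()` masks bits off, e.g. to exercise fallback code paths. A short usage sketch; the bit-twiddling loop and the choice of bits to disable are illustrative, while the functions and the `HWY_AVX3*` bit names come from the existing headers:

```cpp
#include <cstdio>

#include "hwy/targets.h"  // also pulls in hwy/detect_targets.h

int main() {
  const int64_t targets = hwy::SupportedTargets();
  for (int64_t bits = targets; bits != 0; bits &= bits - 1) {
    const int64_t target = bits & -bits;  // lowest set bit
    printf("supported: %s\n", hwy::TargetName(target));
  }
  // Pretend the AVX-512 targets are unavailable, then re-query. Lower bit
  // values denote better targets, so the lowest set bit is the best one.
  hwy::DisableTargets(HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_ZEN4 | HWY_AVX3_SPR);
  const int64_t remaining = hwy::SupportedTargets();
  printf("best after DisableTargets: %s\n",
         hwy::TargetName(remaining & -remaining));
  return 0;
}
```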
#define HWY_MAX_DYNAMIC_TARGETS 9 #define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC -#define HWY_CHOOSE_TARGET_LIST(func_name) \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - nullptr, /* reserved */ \ - HWY_CHOOSE_PPC10(func_name), /* PPC10 */ \ - HWY_CHOOSE_PPC9(func_name), /* PPC9 */ \ - HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ - nullptr, /* reserved (VSX or AltiVec) */ \ - nullptr /* reserved (VSX or AltiVec) */ +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC10(func_name), /* PPC10 */ \ + HWY_CHOOSE_PPC9(func_name), /* PPC9 */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + HWY_CHOOSE_Z15(func_name), /* Z15 */ \ + HWY_CHOOSE_Z14(func_name) /* Z14 */ #elif HWY_ARCH_WASM // See HWY_ARCH_X86 above for details. @@ -316,8 +325,7 @@ struct ChosenTarget { } private: - // TODO(janwas): remove RVV once is available -#if HWY_ARCH_RVV || defined(HWY_NO_LIBCXX) +#if defined(HWY_NO_LIBCXX) int64_t LoadMask() const { return mask_; } void StoreMask(int64_t mask) { mask_ = mask; } @@ -327,7 +335,7 @@ struct ChosenTarget { void StoreMask(int64_t mask) { mask_.store(mask); } std::atomic mask_{1}; // Initialized to 1 so GetIndex() returns 0. -#endif // HWY_ARCH_RVV +#endif // HWY_ARCH_RISCV }; // For internal use (e.g. by FunctionCache and DisableTargets). diff --git a/r/src/vendor/highway/hwy/timer-inl.h b/r/src/vendor/highway/hwy/timer-inl.h index c286b0a8..9e98e6d0 100644 --- a/r/src/vendor/highway/hwy/timer-inl.h +++ b/r/src/vendor/highway/hwy/timer-inl.h @@ -16,6 +16,8 @@ // High-resolution and high-precision timer // Per-target include guard +// NOTE: this file could/should be a normal header, but user code may reference +// hn::timer, and defining that here requires highway.h. #if defined(HIGHWAY_HWY_TIMER_INL_H_) == defined(HWY_TARGET_TOGGLE) #ifdef HIGHWAY_HWY_TIMER_INL_H_ #undef HIGHWAY_HWY_TIMER_INL_H_ @@ -24,7 +26,6 @@ #endif #include "hwy/highway.h" -#include "hwy/timer.h" #if defined(_WIN32) || defined(_WIN64) #ifndef NOMINMAX @@ -50,6 +51,7 @@ #include #endif +#include #include // clock_gettime HWY_BEFORE_NAMESPACE(); @@ -139,8 +141,8 @@ inline Ticks Start() { // "memory" avoids reordering. rdx = TSC >> 32. // "cc" = flags modified by SHL. : "rdx", "memory", "cc"); -#elif HWY_ARCH_RVV - asm volatile("rdtime %0" : "=r"(t)); +#elif HWY_ARCH_RISCV + asm volatile("fence; rdtime %0" : "=r"(t)); #elif defined(_WIN32) || defined(_WIN64) LARGE_INTEGER counter; (void)QueryPerformanceCounter(&counter); diff --git a/r/src/vendor/highway/hwy/timer.cc b/r/src/vendor/highway/hwy/timer.cc index 28b5892e..4b7f2415 100644 --- a/r/src/vendor/highway/hwy/timer.cc +++ b/r/src/vendor/highway/hwy/timer.cc @@ -17,8 +17,10 @@ #include -#include //NOLINT +#include // NOLINT +#include // NOLINT +#include "hwy/base.h" #include "hwy/robust_statistics.h" #include "hwy/timer-inl.h" @@ -33,7 +35,7 @@ namespace platform { namespace { // Measures the actual current frequency of Ticks. We cannot rely on the nominal -// frequency encoded in x86 BrandString because it is misleading on M1 Rosetta, +// frequency encoded in x86 GetCpuString because it is misleading on M1 Rosetta, // and not reported by AMD. CPUID 0x15 is also not yet widely supported. Also // used on RISC-V and aarch64. 
HWY_MAYBE_UNUSED double MeasureNominalClockRate() { @@ -59,7 +61,7 @@ HWY_MAYBE_UNUSED double MeasureNominalClockRate() { const double dticks = static_cast(ticks1 - ticks0); std::chrono::duration> dtime = time1 - time0; const double ticks_per_sec = dticks / dtime.count(); - max_ticks_per_sec = std::max(max_ticks_per_sec, ticks_per_sec); + max_ticks_per_sec = HWY_MAX(max_ticks_per_sec, ticks_per_sec); } return max_ticks_per_sec; } @@ -93,27 +95,33 @@ bool HasRDTSCP() { return (abcd[3] & (1u << 27)) != 0; // RDTSCP } -void GetBrandString(char* cpu100) { +#endif // HWY_ARCH_X86 +} // namespace + +HWY_DLLEXPORT bool GetCpuString(char* cpu100) { +#if HWY_ARCH_X86 uint32_t abcd[4]; // Check if brand string is supported (it is on all reasonable Intel/AMD) Cpuid(0x80000000U, 0, abcd); if (abcd[0] < 0x80000004U) { - cpu100[0] = 0; - return; + cpu100[0] = '\0'; + return false; } for (size_t i = 0; i < 3; ++i) { Cpuid(static_cast(0x80000002U + i), 0, abcd); CopyBytes(&abcd[0], cpu100 + i * 16); // not same size } - cpu100[48] = 0; + cpu100[48] = '\0'; + return true; +#else + cpu100[0] = '?'; + cpu100[1] = '\0'; + return false; +#endif } -#endif // HWY_ARCH_X86 - -} // namespace - HWY_DLLEXPORT double Now() { static const double mul = 1.0 / InvariantTicksPerSecond(); return static_cast(timer::Start()) * mul; @@ -122,18 +130,18 @@ HWY_DLLEXPORT double Now() { HWY_DLLEXPORT bool HaveTimerStop(char* cpu100) { #if HWY_ARCH_X86 if (!HasRDTSCP()) { - GetBrandString(cpu100); + (void)GetCpuString(cpu100); return false; } #endif - (void)cpu100; + *cpu100 = '\0'; return true; } HWY_DLLEXPORT double InvariantTicksPerSecond() { #if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) return static_cast(__ppc_get_timebase_freq()); -#elif HWY_ARCH_X86 || HWY_ARCH_RVV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) +#elif HWY_ARCH_X86 || HWY_ARCH_RISCV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) // We assume the x86 TSC is invariant; it is on all recent Intel/AMD CPUs. static const double freq = MeasureNominalClockRate(); return freq; diff --git a/r/src/vendor/highway/hwy/timer.h b/r/src/vendor/highway/hwy/timer.h index 0ca46e24..7ac0588f 100644 --- a/r/src/vendor/highway/hwy/timer.h +++ b/r/src/vendor/highway/hwy/timer.h @@ -36,7 +36,7 @@ HWY_DLLEXPORT double Now(); // Returns whether it is safe to call timer::Stop without executing an illegal // instruction; if false, fills cpu100 (a pointer to a 100 character buffer) -// with the CPU brand string or an empty string if unknown. +// via GetCpuString(). HWY_DLLEXPORT bool HaveTimerStop(char* cpu100); // Returns tick rate, useful for converting timer::Ticks to seconds. Invariant @@ -49,7 +49,22 @@ HWY_DLLEXPORT double InvariantTicksPerSecond(); // This call is expensive, callers should cache the result. HWY_DLLEXPORT uint64_t TimerResolution(); +// Returns false if no detailed description is available, otherwise fills +// `cpu100` with up to 100 characters (including \0) identifying the CPU model. 
+HWY_DLLEXPORT bool GetCpuString(char* cpu100); + } // namespace platform + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +static inline double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + } // namespace hwy #endif // HIGHWAY_HWY_TIMER_H_ diff --git a/r/src/vendor/highway/manual-build/build_highway.sh b/r/src/vendor/highway/manual-build/build_highway.sh index 1291c138..1acb5f66 100644 --- a/r/src/vendor/highway/manual-build/build_highway.sh +++ b/r/src/vendor/highway/manual-build/build_highway.sh @@ -27,11 +27,14 @@ HWY_FLAGS=( # Skip the CONTRIB, since we don't need sorting or image libraries HWY_SOURCES=( - hwy/aligned_allocator.cc + hwy/abort.cc + hwy/aligned_allocator.cc hwy/nanobenchmark.cc hwy/per_target.cc hwy/print.cc + hwy/stats.cc hwy/targets.cc + hwy/timer.cc ) HWY_HEADERS=( @@ -40,6 +43,7 @@ HWY_HEADERS=( hwy/contrib/algo/find-inl.h hwy/contrib/algo/transform-inl.h + hwy/abort.h hwy/aligned_allocator.h hwy/base.h hwy/cache_control.h @@ -53,21 +57,26 @@ HWY_HEADERS=( hwy/ops/arm_sve-inl.h hwy/ops/emu128-inl.h hwy/ops/generic_ops-inl.h + hwy/ops/inside-inl.h hwy/ops/ppc_vsx-inl.h hwy/ops/rvv-inl.h hwy/ops/scalar-inl.h hwy/ops/set_macros-inl.h hwy/ops/shared-inl.h hwy/ops/wasm_128-inl.h - hwy/ops/tuple-inl.h + hwy/ops/wasm_256-inl.h hwy/ops/x86_128-inl.h hwy/ops/x86_256-inl.h hwy/ops/x86_512-inl.h hwy/per_target.h hwy/print-inl.h hwy/print.h + hwy/profiler.h + hwy/robust_statistics.h + hwy/stats.h hwy/targets.h hwy/timer-inl.h + hwy/timer.h )
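Finally, the timer additions provide a portable wall-clock pair, `platform::Now()` plus the new `Timestamp`/`SecondsSince()` helpers, with `GetCpuString()` available for labelling measurements. A small sketch of how they compose; the busy loop is only a placeholder workload:

```cpp
#include <cstdio>

#include "hwy/timer.h"

int main() {
  char cpu100[100];
  if (hwy::platform::GetCpuString(cpu100)) {
    printf("CPU: %s\n", cpu100);
  }
  const hwy::Timestamp t0;  // records platform::Now() at construction
  double sum = 0.0;
  for (int i = 0; i < 1000000; ++i) sum += i * 1E-9;  // placeholder workload
  printf("elapsed: %.6f s (checksum %.3f)\n", hwy::SecondsSince(t0), sum);
  return 0;
}
```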