From c9b3669e8f99d80542a03c14aaf3b6b6ed7df15b Mon Sep 17 00:00:00 2001
From: Riwan
Date: Tue, 7 Jan 2025 14:17:32 +0100
Subject: [PATCH] feat: add rsqrt with prec step on NEON float 32

---
 include/mipp.h             |  1 +
 include/mipp_impl_NEON.hxx | 14 ++++++++++++++
 include/mipp_object.hxx    |  3 +++
 3 files changed, 18 insertions(+)

diff --git a/include/mipp.h b/include/mipp.h
index d23f770..6688266 100644
--- a/include/mipp.h
+++ b/include/mipp.h
@@ -793,6 +793,7 @@ template inline reg neg (const reg, const msk)
 template inline reg abs (const reg) { errorMessage("abs"); exit(-1); }
 template inline reg sqrt (const reg) { errorMessage("sqrt"); exit(-1); }
 template inline reg rsqrt (const reg) { errorMessage("rsqrt"); exit(-1); }
+template inline reg rsqrt_prec (const reg v) { return rsqrt(v); }
 template inline reg log (const reg) { errorMessage("log"); exit(-1); }
 template inline reg exp (const reg) { errorMessage("exp"); exit(-1); }
 template inline reg sin (const reg) { errorMessage("sin"); exit(-1); }
diff --git a/include/mipp_impl_NEON.hxx b/include/mipp_impl_NEON.hxx
index 891e3df..2cae37a 100644
--- a/include/mipp_impl_NEON.hxx
+++ b/include/mipp_impl_NEON.hxx
@@ -2815,6 +2815,20 @@
         return vrsqrteq_f32(v1);
     }
 
+    // rsqrt_prec (float): reciprocal square root estimate refined by one Newton-Raphson step.
+    template <>
+    inline reg rsqrt_prec(const reg v1) {
+        // Raw hardware estimate of 1 / sqrt(v1).
+        const float32x4_t approx = vrsqrteq_f32(v1);
+        // One Newton-Raphson step: approx * (3 - v1 * approx * approx) / 2.
+        const float32x4_t refined = vmulq_f32(vrsqrtsq_f32(vmulq_f32(v1, approx), approx), approx);
+        // For v1 == 0 or v1 == +inf the product v1 * approx is 0 * inf = NaN, which would turn
+        // the result into NaN. Keep the raw estimate (+inf or 0) in every lane where the refinement is NaN.
+        const uint32x4_t refined_ok = vceqq_f32(refined, refined);
+        return vbslq_f32(refined_ok, refined, approx);
+    }
+
+
     // ----------------------------------------------------------------------------------------------------------- sqrt
 #ifdef __aarch64__
     template <>
diff --git a/include/mipp_object.hxx b/include/mipp_object.hxx
index f8f0f97..cf41028 100644
--- a/include/mipp_object.hxx
+++ b/include/mipp_object.hxx
@@ -204,6 +204,7 @@ public:
     inline Reg abs () const { return mipp::abs (r); }
     inline Reg sqrt () const { return mipp::sqrt (r); }
     inline Reg rsqrt () const { return mipp::rsqrt (r); }
+    inline Reg rsqrt_prec () const { return mipp::rsqrt_prec (r); }
     inline Reg log () const { return mipp::log (r); }
     inline Reg exp () const { return mipp::exp (r); }
     inline Reg sin () const { return mipp::sin (r); }
@@ -305,6 +306,7 @@ public:
     inline Reg abs () const { return std::abs(r); }
     inline Reg sqrt () const { return (T)std::sqrt(r); }
     inline Reg rsqrt () const { return (T)(1 / std::sqrt(r)); }
+    inline Reg rsqrt_prec () const { return (T)(1 / std::sqrt(r)); }
     inline Reg log () const { return (T)std::log(r); }
     inline Reg exp () const { return (T)std::exp(r); }
     inline Reg sin () const { return (T)std::sin(r); }
@@ -920,6 +922,7 @@ template inline Reg copysign (const Reg v1, c
 template inline Reg abs (const Reg v) { return v.abs(); }
 template inline Reg sqrt (const Reg v) { return v.sqrt(); }
 template inline Reg rsqrt (const Reg v) { return v.rsqrt(); }
+template inline Reg rsqrt_prec (const Reg v) { return v.rsqrt_prec(); }
 template inline Reg log (const Reg v) { return v.log(); }
 template inline Reg exp (const Reg v) { return v.exp(); }
 template inline Reg sin (const Reg v) { return v.sin(); }