diff --git a/src/algorithm/bluesteins_algorithm.rs b/src/algorithm/bluesteins_algorithm.rs index 315aa39..297421f 100644 --- a/src/algorithm/bluesteins_algorithm.rs +++ b/src/algorithm/bluesteins_algorithm.rs @@ -3,8 +3,6 @@ use std::sync::Arc; use num_complex::Complex; use num_traits::Zero; -use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 7bae332..88448b2 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -2,8 +2,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils::{self, DoubleBuf, LoadStore}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{DoubleBuf, LoadStore}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -24,31 +23,19 @@ macro_rules! 
boilerplate_fft_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( + crate::fft_helper::fft_helper_immut( input, output, + &mut [], self.len(), - |in_chunk, out_chunk| { - unsafe { - self.perform_fft_butterfly(DoubleBuf { - input: in_chunk, - output: out_chunk, - }) - }; + 0, + |in_chunk, out_chunk, _| unsafe { + self.perform_fft_butterfly(DoubleBuf { + input: in_chunk, + output: out_chunk, + }) }, ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - } } fn process_outofplace_with_scratch( &self, @@ -56,48 +43,28 @@ macro_rules! 
boilerplate_fft_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( + crate::fft_helper::fft_helper_outofplace( input, output, + &mut [], self.len(), - |in_chunk, out_chunk| { - unsafe { - self.perform_fft_butterfly(DoubleBuf { - input: in_chunk, - output: out_chunk, - }) - }; + 0, + |in_chunk, out_chunk, _| unsafe { + self.perform_fft_butterfly(DoubleBuf { + input: in_chunk, + output: out_chunk, + }) }, ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| unsafe { - self.perform_fft_butterfly(chunk) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call 
a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - } + crate::fft_helper::fft_helper_inplace( + buffer, + &mut [], + self.len(), + 0, + |chunk, _| unsafe { self.perform_fft_butterfly(chunk) }, + ); } #[inline(always)] fn get_inplace_scratch_len(&self) -> usize { diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index 95f781f..29b99de 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -1,8 +1,6 @@ use num_complex::Complex; use num_traits::Zero; -use crate::array_utils; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; @@ -44,6 +42,9 @@ impl Dft { fn outofplace_scratch_len(&self) -> usize { 0 } + fn immut_scratch_len(&self) -> usize { + 0 + } fn perform_fft_immut( &self, @@ -78,7 +79,7 @@ impl Dft { self.perform_fft_immut(signal, spectrum, _scratch); } } -boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len(), |_: &Dft<_>| 0); +boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len()); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/good_thomas_algorithm.rs b/src/algorithm/good_thomas_algorithm.rs index db4b591..2122310 100644 --- a/src/algorithm/good_thomas_algorithm.rs +++ b/src/algorithm/good_thomas_algorithm.rs @@ -7,7 +7,6 @@ use strength_reduce::StrengthReducedUsize; use transpose; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/algorithm/mixed_radix.rs b/src/algorithm/mixed_radix.rs index 66405eb..317d4cb 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -6,7 +6,6 @@ use num_traits::Zero; use transpose; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; 
diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 3a4df60..76ae34b 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -6,8 +6,6 @@ use num_traits::Zero; use primal_check::miller_rabin; use strength_reduce::StrengthReducedUsize; -use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::math_utils; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index 182ab4f..526a0a8 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -4,8 +4,7 @@ use num_complex::Complex; use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9}; use crate::algorithm::radixn::butterfly_3; -use crate::array_utils::{self, bitreversed_transpose, compute_logarithm}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{bitreversed_transpose, compute_logarithm}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -120,6 +119,9 @@ impl Radix3 { fn outofplace_scratch_len(&self) -> usize { self.outofplace_scratch_len } + fn immut_scratch_len(&self) -> usize { + self.immut_scratch_len + } fn perform_fft_immut( &self, @@ -193,8 +195,7 @@ impl Radix3 { } } } -boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len, |this: &Radix3<_>| this - .immut_scratch_len); +boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index b040cfe..ec18de5 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -6,8 +6,7 @@ use crate::algorithm::butterflies::{ Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8, }; use crate::algorithm::radixn::butterfly_4; -use crate::array_utils::{self, bitreversed_transpose}; -use 
crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::bitreversed_transpose; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -125,8 +124,10 @@ impl Radix4 { fn outofplace_scratch_len(&self) -> usize { self.outofplace_scratch_len } + fn immut_scratch_len(&self) -> usize { + self.immut_scratch_len + } - #[inline] fn perform_fft_immut( &self, input: &[Complex], @@ -201,8 +202,7 @@ impl Radix4 { } } } -boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len, |this: &Radix4<_>| this - .immut_scratch_len); +boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index a85a6bf..7467236 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use num_complex::Complex; -use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace, RadixFactor}; +use crate::array_utils::{factor_transpose, Load, LoadStore, TransposeFactor}; +use crate::common::RadixFactor; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -160,6 +160,9 @@ impl RadixN { fn outofplace_scratch_len(&self) -> usize { self.outofplace_scratch_len } + fn immut_scratch_len(&self) -> usize { + self.immut_scratch_len + } fn perform_fft_immut( &self, @@ -329,8 +332,7 @@ impl RadixN { } } } -boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len, |this: &RadixN<_>| this - .immut_scratch_len); +boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len); #[inline(never)] pub(crate) unsafe fn butterfly_2( diff --git a/src/array_utils.rs b/src/array_utils.rs index 4058987..bfae7df 100644 --- a/src/array_utils.rs +++ b/src/array_utils.rs @@ -143,22 +143,32 @@ mod unit_tests { } } -// Loop over exact chunks of the provided buffer. 
Very similar in semantics to ChunksExactMut, but generates smaller code and requires no modulo operations -// Returns Ok() if every element ended up in a chunk, Err() if there was a remainder -pub fn iter_chunks_mut( +// A utility that validates the following conditions, then calls chunk_fn() on each chunk of buffer. Passes the entire scratch buffer with each call. +// - buffer.len() % chunk_size == 0 +// - scratch.len() >= required_scratch +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +// Since this is duplicated into every FFT algorithm we provide, this is tuned to reduce code size as much as possible, with a secondary focus being on ease of implementation +pub fn validate_and_iter( mut buffer: &mut [T], + scratch: &mut [T], chunk_size: usize, - mut chunk_fn: impl FnMut(&mut [T]), + required_scratch: usize, + mut chunk_fn: impl FnMut(&mut [T], &mut [T]), ) -> Result<(), ()> { - // Loop over the buffer, splicing off chunk_size at a time, and calling chunk_fn on each + if scratch.len() < required_scratch { + return Err(()); + } + let scratch = &mut scratch[..required_scratch]; + + // Now that the scratch length is validated, loop over the buffer, splicing off chunk_size at a time, and calling chunk_fn on each chunk while buffer.len() >= chunk_size { let (head, tail) = buffer.split_at_mut(chunk_size); buffer = tail; - chunk_fn(head); + chunk_fn(head, scratch); } - // We have a remainder if there's data still in the buffer -- in which case we want to indicate to the caller that there was an unwanted remainder + // We have a remainder if there's still data in the buffer -- in which case we want to indicate to the caller that there was an unwanted remainder if buffer.len() == 0 { Ok(()) } else { @@ -166,80 +176,189 @@ pub fn iter_chunks_mut( } } -// Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations -// Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder -pub fn iter_chunks_zipped( +// A utility that validates that buffer.len() % chunk_size == 0, then calls chunk2x_fn() on each pair of chunks of buffer, and chunk_fn() on the final single chunk if one is left over. +// This version does 2x partial unrolling of the buffer, because most SIMD butterfly algorithms operate that way. +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +pub fn validate_and_iter_unroll2x( + mut buffer: &mut [T], + chunk_size: usize, + mut chunk2x_fn: impl FnMut(&mut [T]), + mut chunk_fn: impl FnMut(&mut [T]), +) -> Result<(), ()> { + // Loop over the buffer, splicing off chunk_size * 2 at a time, and calling chunk2x_fn on each + while buffer.len() >= chunk_size * 2 { + let (head, tail) = buffer.split_at_mut(chunk_size * 2); + buffer = tail; + + chunk2x_fn(head); + } + + if buffer.len() == chunk_size { + chunk_fn(buffer); + Ok(()) + } else if buffer.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// A utility that validates the following conditions, then calls chunk_fn() on each chunk of buffer1 and buffer2 zipped together. Passes the entire scratch buffer with each call. 
+// - buffer1.len() == buffer2.len() +// - buffer1.len() % chunk_size == 0 +// - scratch.len() >= required_scratch +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +// Since this is duplicated into every FFT algorithm we provide, this is tuned to reduce code size as much as possible, with a secondary focus being on ease of implementation +pub fn validate_and_zip( mut buffer1: &[T], mut buffer2: &mut [T], + scratch: &mut [T], chunk_size: usize, - mut chunk_fn: impl FnMut(&[T], &mut [T]), + required_scratch: usize, + mut chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), ) -> Result<(), ()> { - // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size - let uneven = match buffer1.len().cmp(&buffer2.len()) { - std::cmp::Ordering::Less => { - buffer2 = &mut buffer2[..buffer1.len()]; - true - } - std::cmp::Ordering::Equal => false, - std::cmp::Ordering::Greater => { - buffer1 = &buffer1[..buffer2.len()]; - true - } - }; + if scratch.len() < required_scratch { + return Err(()); + } + let scratch = &mut scratch[..required_scratch]; + + if buffer1.len() != buffer2.len() { + return Err(()); + } // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each - while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size { + while buffer1.len() >= chunk_size { let (head1, tail1) = buffer1.split_at(chunk_size); buffer1 = tail1; let (head2, tail2) = buffer2.split_at_mut(chunk_size); buffer2 = tail2; - chunk_fn(head1, head2); + chunk_fn(head1, head2, scratch); } // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder - if !uneven && buffer1.len() == 0 { + if buffer1.len() == 0 { Ok(()) } else { Err(()) } } -// Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations -// Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder -pub fn iter_chunks_zipped_mut( +// A utility that validates the following conditions, then calls chunk2x_fn() on each pair of chunks of buffer1 and buffer2 zipped together, and chunk_fn() on the final single chunk if one is left over. +// - buffer1.len() == buffer2.len() +// - buffer1.len() % chunk_size == 0 +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +// This version does 2x partial unrolling of the buffer, because most SIMD butterfly algorithms operate that way. +// Since this is duplicated into every FFT algorithm we provide, this is tuned to reduce code size as much as possible, with a secondary focus being on ease of implementation +pub fn validate_and_zip_unroll2x( + mut buffer1: &[T], + mut buffer2: &mut [T], + chunk_size: usize, + mut chunk2x_fn: impl FnMut(&[T], &mut [T]), + mut chunk_fn: impl FnMut(&[T], &mut [T]), +) -> Result<(), ()> { + if buffer1.len() != buffer2.len() { + return Err(()); + } + + // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size * 2 at a time, and calling chunk2x_fn on each + while buffer1.len() >= chunk_size * 2 { + let (head1, tail1) = buffer1.split_at(chunk_size * 2); + buffer1 = tail1; + + let (head2, tail2) = buffer2.split_at_mut(chunk_size * 2); + buffer2 = tail2; + + chunk2x_fn(head1, head2); + } + + // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder + if buffer1.len() == chunk_size { + chunk_fn(buffer1, buffer2); + Ok(()) + } else if buffer1.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// A utility that validates the following conditions, then calls chunk_fn() on each chunk of buffer1 and buffer2 
zipped together. Passes the entire scratch buffer with each call. +// - buffer1.len() == buffer2.len() +// - buffer1.len() % chunk_size == 0 +// - scratch.len() >= required_scratch +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +// Since this is duplicated into every FFT algorithm we provide, this is tuned to reduce code size as much as possible, with a secondary focus being on ease of implementation +pub fn validate_and_zip_mut( mut buffer1: &mut [T], mut buffer2: &mut [T], + scratch: &mut [T], chunk_size: usize, - mut chunk_fn: impl FnMut(&mut [T], &mut [T]), + required_scratch: usize, + mut chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), ) -> Result<(), ()> { - // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size - let uneven = match buffer1.len().cmp(&buffer2.len()) { - std::cmp::Ordering::Less => { - buffer2 = &mut buffer2[..buffer1.len()]; - true - } - std::cmp::Ordering::Equal => false, - std::cmp::Ordering::Greater => { - buffer1 = &mut buffer1[..buffer2.len()]; - true - } - }; + if scratch.len() < required_scratch { + return Err(()); + } + let scratch = &mut scratch[..required_scratch]; + + if buffer1.len() != buffer2.len() { + return Err(()); + } // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each - while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size { + while buffer1.len() >= chunk_size { let (head1, tail1) = buffer1.split_at_mut(chunk_size); buffer1 = tail1; let (head2, tail2) = buffer2.split_at_mut(chunk_size); buffer2 = tail2; - chunk_fn(head1, head2); + chunk_fn(head1, head2, scratch); } // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder - if !uneven && buffer1.len() == 0 { + if buffer1.len() == 0 { + Ok(()) + } 
else { + Err(()) + } +} + +// A utility that validates the following conditions, then calls chunk2x_fn() on each pair of chunks of buffer1 and buffer2 zipped together, and chunk_fn() on the final single chunk if one is left over. +// - buffer1.len() == buffer2.len() +// - buffer1.len() % chunk_size == 0 +// Returns Ok(()) if the validation passed, Err(()) if there was a problem +// This version does 2x partial unrolling of the buffer, because most SIMD butterfly algorithms operate that way. +// Since this is duplicated into every FFT algorithm we provide, this is tuned to reduce code size as much as possible, with a secondary focus being on ease of implementation +pub fn validate_and_zip_mut_unroll2x( + mut buffer1: &mut [T], + mut buffer2: &mut [T], + chunk_size: usize, + mut chunk2x_fn: impl FnMut(&mut [T], &mut [T]), + mut chunk_fn: impl FnMut(&mut [T], &mut [T]), +) -> Result<(), ()> { + if buffer1.len() != buffer2.len() { + return Err(()); + } + + // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size * 2 at a time, and calling chunk2x_fn on each + while buffer1.len() >= chunk_size * 2 { + let (head1, tail1) = buffer1.split_at_mut(chunk_size * 2); + buffer1 = tail1; + + let (head2, tail2) = buffer2.split_at_mut(chunk_size * 2); + buffer2 = tail2; + + chunk2x_fn(head1, head2); + } + + // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder + if buffer1.len() == chunk_size { + chunk_fn(buffer1, buffer2); + Ok(()) + } else if buffer1.len() == 0 { Ok(()) } else { Err(()) diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index e0c8539..7f70308 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -4,10 +4,7 @@ use std::mem::MaybeUninit; use num_complex::Complex; -use crate::array_utils; use crate::array_utils::DoubleBuf; -use 
crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -43,33 +40,17 @@ macro_rules! boilerplate_fft_simd_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_f32(DoubleBuf { - input: input_slice, - output: output_slice, - }); - } - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_f32(DoubleBuf { input, output }), + ); } } fn process_outofplace_with_scratch( @@ -78,53 +59,29 @@ macro_rules! 
boilerplate_fft_simd_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_f32(DoubleBuf { - input: input_slice, - output: output_slice, - }); - } - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_f32(DoubleBuf { input, output }), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because 
fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - self.perform_fft_f32(workaround_transmute_mut::<_, Complex>(chunk)); - } - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::avx_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_f32(chunk), + ) } } #[inline(always)] @@ -204,15 +161,6 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available unsafe { self.row_butterflies(output) }; } - - #[inline] - fn perform_fft_out_of_place( - &self, - input: &mut [Complex], - output: &mut [Complex], - ) { - self.perform_fft_immut(input, output); - } } impl Fft for $struct_name { fn process_immutable_with_scratch( @@ -221,28 +169,17 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_input: &[Complex] = - unsafe { array_utils::workaround_transmute(input) }; - let transmuted_output: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( - transmuted_input, - transmuted_output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_immut(input, output), + ); } } fn process_outofplace_with_scratch( @@ -251,53 +188,30 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_input: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(input) }; - let transmuted_output: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped_mut( - transmuted_input, - transmuted_output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_out_of_place(in_chunk, out_chunk), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_immut(input, output), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - let required_scratch = self.len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so 
call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len()); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_buffer: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(buffer) }; - let transmuted_scratch: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, transmuted_scratch) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len()); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + let simd_scratch = crate::array_utils::workaround_transmute_mut(scratch); + super::avx_fft_helper_inplace( + simd_buffer, + simd_scratch, + self.len(), + self.len(), + |chunk, scratch| self.perform_fft_inplace(chunk, scratch), + ) } } #[inline(always)] diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index e6bc7d0..1d00f15 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -4,10 +4,7 @@ use std::mem::MaybeUninit; use num_complex::Complex; -use crate::array_utils; use crate::array_utils::DoubleBuf; -use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, 
Fft, FftDirection, Length}; @@ -41,33 +38,17 @@ macro_rules! boilerplate_fft_simd_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_f64(DoubleBuf { - input: input_slice, - output: output_slice, - }); - } - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_f64(DoubleBuf { input, output }), + ); } } fn process_outofplace_with_scratch( @@ -76,53 +57,29 @@ macro_rules! 
boilerplate_fft_simd_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_f64(DoubleBuf { - input: input_slice, - output: output_slice, - }); - } - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_f64(DoubleBuf { input, output }), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because 
fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - self.perform_fft_f64(workaround_transmute_mut::<_, Complex>(chunk)); - } - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::avx_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_f64(chunk), + ) } } #[inline(always)] @@ -204,15 +161,6 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available unsafe { self.row_butterflies(output) }; } - - #[inline] - fn perform_fft_out_of_place( - &self, - input: &[Complex], - output: &mut [Complex], - ) { - self.perform_fft_immut(input, output); - } } impl Fft for $struct_name { fn process_immutable_with_scratch( @@ -221,28 +169,17 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_input: &[Complex] = - unsafe { array_utils::workaround_transmute(input) }; - let transmuted_output: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( - transmuted_input, - transmuted_output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_immut(input, output), + ); } } fn process_outofplace_with_scratch( @@ -251,53 +188,30 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_input: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(input) }; - let transmuted_output: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped_mut( - transmuted_input, - transmuted_output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_out_of_place(in_chunk, out_chunk), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::avx_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_immut(input, output), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - let required_scratch = self.len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so 
call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len()); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_buffer: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(buffer) }; - let transmuted_scratch: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, transmuted_scratch) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), self.len(), scratch.len()); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + let simd_scratch = crate::array_utils::workaround_transmute_mut(scratch); + super::avx_fft_helper_inplace( + simd_buffer, + simd_scratch, + self.len(), + self.len(), + |chunk, scratch| self.perform_fft_inplace(chunk, scratch), + ) } } #[inline(always)] diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index d85020b..28430b3 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -5,7 +5,6 @@ use num_complex::Complex; use num_integer::div_ceil; use num_traits::Zero; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index a64476c..834db4b 
100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -5,7 +5,6 @@ use num_complex::Complex; use num_integer::div_ceil; use crate::array_utils; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{Direction, Fft, FftDirection, FftNum, Length}; use super::{AvxNum, CommonSimdData}; @@ -36,8 +35,12 @@ macro_rules! boilerplate_mixedradix { } } - #[inline] - fn perform_fft_inplace(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_fft_inplace( + &self, + buffer: &mut [Complex], + scratch: &mut [Complex], + ) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available unsafe { @@ -69,17 +72,16 @@ macro_rules! boilerplate_mixedradix { } } - #[inline] - fn perform_fft_immut( + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], ) { // Perform the column FFTs - // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available let (scratch, inner_scratch) = scratch.split_at_mut(input.len()); - unsafe { + { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); let transmuted_output: &mut [Complex] = @@ -94,8 +96,7 @@ macro_rules! 
boilerplate_mixedradix { .process_with_scratch(scratch, inner_scratch); // Transpose - // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available - unsafe { + { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_input: &mut [Complex] = array_utils::workaround_transmute_mut(scratch); @@ -106,20 +107,18 @@ macro_rules! boilerplate_mixedradix { } } - #[inline] - fn perform_fft_out_of_place( + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_fft_out_of_place( &self, input: &mut [Complex], output: &mut [Complex], scratch: &mut [Complex], ) { // Perform the column FFTs - // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available - unsafe { + { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_input: &mut [Complex] = array_utils::workaround_transmute_mut(input); - self.perform_column_butterflies(transmuted_input); } @@ -134,8 +133,7 @@ macro_rules! 
boilerplate_mixedradix { .process_with_scratch(input, inner_scratch); // Transpose - // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available - unsafe { + { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_input: &mut [Complex] = array_utils::workaround_transmute_mut(input); diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index 439395e..b5ebd6a 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -8,7 +8,6 @@ use num_traits::Zero; use primal_check::miller_rabin; use strength_reduce::StrengthReducedUsize; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, FftDirection}; use crate::{math_utils, twiddles}; use crate::{Direction, Fft, FftNum, Length}; diff --git a/src/avx/avx_vector.rs b/src/avx/avx_vector.rs index 97807f2..d53035d 100644 --- a/src/avx/avx_vector.rs +++ b/src/avx/avx_vector.rs @@ -14,7 +14,6 @@ use super::AvxNum; /// /// The goal of this trait is to reduce code duplication by letting code be generic over the vector type pub trait AvxVector: Copy + Debug + Send + Sync { - const SCALAR_PER_VECTOR: usize; const COMPLEX_PER_VECTOR: usize; // useful constants @@ -729,7 +728,6 @@ impl Rotation90 { } impl AvxVector for __m256 { - const SCALAR_PER_VECTOR: usize = 8; const COMPLEX_PER_VECTOR: usize = 4; #[inline(always)] @@ -1102,7 +1100,6 @@ impl AvxVector256 for __m256 { } impl AvxVector for __m128 { - const SCALAR_PER_VECTOR: usize = 4; const COMPLEX_PER_VECTOR: usize = 2; #[inline(always)] @@ -1375,7 +1372,6 @@ impl AvxVector128 for __m128 { } impl AvxVector for __m256d { - const SCALAR_PER_VECTOR: usize = 4; const COMPLEX_PER_VECTOR: usize = 2; #[inline(always)] @@ -1684,7 +1680,6 @@ impl AvxVector256 for __m256d { } impl AvxVector for __m128d { - const SCALAR_PER_VECTOR: usize = 2; const 
COMPLEX_PER_VECTOR: usize = 1; #[inline(always)] @@ -1900,9 +1895,6 @@ pub trait AvxArrayMut: AvxArray + DerefMut { index: usize, ); unsafe fn store_partial3_complex(&mut self, data: T::VectorType, index: usize); - - // some avx operations need bespoke one-off things that don't fit into the methods above, so we should provide an escape hatch for them - fn output_ptr(&mut self) -> *mut Complex; } impl AvxArray for &[Complex] { @@ -2030,10 +2022,6 @@ impl AvxArrayMut for &mut [Complex] { debug_assert!(self.len() >= index + 3); T::VectorType::store_partial3_complex(self.as_mut_ptr().add(index), data); } - #[inline(always)] - fn output_ptr(&mut self) -> *mut Complex { - self.as_mut_ptr() - } } impl<'a, T: AvxNum> AvxArrayMut for DoubleBuf<'a, T> where @@ -2064,10 +2052,6 @@ where unsafe fn store_partial3_complex(&mut self, data: T::VectorType, index: usize) { self.output.store_partial3_complex(data, index); } - #[inline(always)] - fn output_ptr(&mut self) -> *mut Complex { - self.output.output_ptr() - } } // A custom butterfly-16 function that calls a lambda to load/store data instead of taking an array diff --git a/src/avx/mod.rs b/src/avx/mod.rs index 4aec9df..b23d0b7 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -1,3 +1,4 @@ +use crate::fft_helper::{fft_helper_immut, fft_helper_inplace, fft_helper_outofplace}; use crate::{Fft, FftDirection, FftNum}; use std::arch::x86_64::{__m256, __m256d}; use std::sync::Arc; @@ -37,39 +38,16 @@ macro_rules! 
boilerplate_avx_fft { output: &mut [Complex], scratch: &mut [Complex], ) { - let required_scratch = self.get_immutable_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut( + unsafe { + super::avx_fft_helper_immut( + input, + output, + scratch, self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), + self.get_immutable_scratch_len(), + |in_chunk, out_chunk, scratch| { + self.perform_fft_immut(in_chunk, out_chunk, scratch) + }, ) } } @@ -80,70 +58,27 @@ macro_rules! 
boilerplate_avx_fft { output: &mut [Complex], scratch: &mut [Complex], ) { - let required_scratch = self.get_outofplace_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped_mut( - input, - output, - self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + unsafe { + super::avx_fft_helper_outofplace( + input, + output, + scratch, self.len(), - input.len(), - output.len(), self.get_outofplace_scratch_len(), - scratch.len(), + |in_chunk, out_chunk, scratch| { + self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) + }, ) } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( + unsafe { + super::avx_fft_helper_inplace( + buffer, + scratch, self.len(), - buffer.len(), self.get_inplace_scratch_len(), - scratch.len(), 
- ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, scratch) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), + |chunk, scratch| self.perform_fft_inplace(chunk, scratch), ) } } @@ -184,44 +119,17 @@ macro_rules! boilerplate_avx_fft_commondata { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.common_data.immut_scratch_len; - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut( + unsafe { + super::avx_fft_helper_immut( + input, + output, + scratch, self.len(), - input.len(), - output.len(), self.get_immutable_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( - input, - output, - self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - 
fft_error_immut( - self.len(), - input.len(), - output.len(), - self.get_immutable_scratch_len(), - scratch.len(), - ); + |in_chunk, out_chunk, scratch| { + self.perform_fft_immut(in_chunk, out_chunk, scratch) + }, + ) } } fn process_outofplace_with_scratch( @@ -230,79 +138,28 @@ macro_rules! boilerplate_avx_fft_commondata { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_outofplace_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + unsafe { + super::avx_fft_helper_outofplace( + input, + output, + scratch, self.len(), - input.len(), - output.len(), self.get_outofplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped_mut( - input, - output, - self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) - }, - ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); + |in_chunk, out_chunk, scratch| { + self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) + }, + ) } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if 
scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, scratch) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( + unsafe { + super::avx_fft_helper_inplace( + buffer, + scratch, self.len(), - buffer.len(), self.get_inplace_scratch_len(), - scratch.len(), - ); + |chunk, scratch| self.perform_fft_inplace(chunk, scratch), + ) } } #[inline(always)] @@ -333,6 +190,61 @@ macro_rules! 
boilerplate_avx_fft_commondata { }; } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the AVX and FMA target features, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "avx", enable = "fma")] +pub unsafe fn avx_fft_helper_immut( + input: &[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), +) { + fft_helper_immut( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the AVX and FMA target features, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "avx", enable = "fma")] +pub unsafe fn avx_fft_helper_outofplace( + input: &mut [T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), +) { + fft_helper_outofplace( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the AVX and FMA target features, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "avx", enable = "fma")] +pub unsafe fn avx_fft_helper_inplace( + buffer: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_inplace(buffer, scratch, chunk_size, required_scratch, chunk_fn) +} + #[macro_use] mod avx_vector; diff --git a/src/common.rs b/src/common.rs index 1670ebf..2dd1db1 100644 --- a/src/common.rs +++ b/src/common.rs @@ -104,7 +104,7 @@ pub fn fft_error_immut( } macro_rules! 
boilerplate_fft_oop { - ($struct_name:ident, $len_fn:expr, $immut_scratch_len:expr) => { + ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { fn process_immutable_with_scratch( &self, @@ -112,44 +112,16 @@ macro_rules! boilerplate_fft_oop { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_immutable_scratch_len(); - if input.len() < self.len() - || output.len() != input.len() - || scratch.len() < required_scratch - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped( + crate::fft_helper::fft_helper_immut( input, output, + scratch, self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + self.get_immutable_scratch_len(), + |in_chunk, out_chunk, scratch| { + self.perform_fft_immut(in_chunk, out_chunk, scratch) + }, ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - } } fn process_outofplace_with_scratch( &self, @@ -157,74 +129,29 @@ macro_rules! 
boilerplate_fft_oop { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_outofplace_scratch_len(); - if input.len() < self.len() - || output.len() != input.len() - || scratch.len() < required_scratch - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = array_utils::iter_chunks_zipped_mut( + crate::fft_helper::fft_helper_outofplace( input, output, + scratch, self.len(), - |in_chunk, out_chunk| { + self.get_outofplace_scratch_len(), + |in_chunk, out_chunk, scratch| { self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) }, ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let (scratch, extra_scratch) = 
scratch.split_at_mut(self.len()); - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_out_of_place(chunk, scratch, extra_scratch); - chunk.copy_from_slice(scratch); - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - } + crate::fft_helper::fft_helper_inplace( + buffer, + scratch, + self.len(), + self.get_inplace_scratch_len(), + |chunk, scratch| { + let (self_scratch, inner_scratch) = scratch.split_at_mut(self.len()); + self.perform_fft_out_of_place(chunk, self_scratch, inner_scratch); + chunk.copy_from_slice(self_scratch); + }, + ); } #[inline(always)] fn get_inplace_scratch_len(&self) -> usize { @@ -236,7 +163,7 @@ macro_rules! boilerplate_fft_oop { } #[inline(always)] fn get_immutable_scratch_len(&self) -> usize { - $immut_scratch_len(self) + self.immut_scratch_len() } } impl Length for $struct_name { @@ -263,39 +190,16 @@ macro_rules! 
boilerplate_fft { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - let required_scratch = self.get_immutable_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - crate::common::fft_error_immut( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + crate::fft_helper::fft_helper_immut( input, output, + scratch, self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + self.get_immutable_scratch_len(), + |in_chunk, out_chunk, scratch| { + self.perform_fft_immut(in_chunk, out_chunk, scratch) + }, ); - if result.is_err() { - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - required_scratch, - scratch.len(), - ); - } } fn process_outofplace_with_scratch( @@ -304,80 +208,27 @@ macro_rules! boilerplate_fft { output: &mut [Complex], scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_outofplace_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped_mut( + crate::fft_helper::fft_helper_outofplace( input, output, + scratch, self.len(), - |in_chunk, out_chunk| { + self.get_outofplace_scratch_len(), + |in_chunk, out_chunk, scratch| { 
self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) }, ); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, scratch) - }); - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - } + crate::fft_helper::fft_helper_inplace( + buffer, + scratch, + self.len(), + self.get_inplace_scratch_len(), + |chunk, scratch| { + self.perform_fft_inplace(chunk, scratch); + }, + ); } #[inline(always)] fn get_inplace_scratch_len(&self) -> usize { diff --git a/src/fft_helper.rs b/src/fft_helper.rs new file mode 100644 index 
0000000..2eda8c0 --- /dev/null +++ b/src/fft_helper.rs @@ -0,0 +1,175 @@ +use crate::{ + array_utils, + common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}, +}; + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[inline(always)] +pub fn fft_helper_inplace( + buffer: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = + array_utils::validate_and_iter(buffer, scratch, chunk_size, required_scratch, chunk_fn); + + if result.is_err() { + // We want to trigger a panic, because the passed parameters failed validation in some way. + // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_inplace(chunk_size, buffer.len(), required_scratch, scratch.len()); + } +} + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[allow(dead_code)] +#[inline(always)] +pub fn fft_helper_inplace_unroll2x( + buffer: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T]), + chunk_fn: impl FnMut(&mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = array_utils::validate_and_iter_unroll2x(buffer, chunk_size, chunk2x_fn, chunk_fn); + + if result.is_err() { + // We want to trigger a panic, because the passed parameters failed validation in some way. 
+ // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_inplace(chunk_size, buffer.len(), 0, 0); + } +} + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[inline(always)] +pub fn fft_helper_immut( + input: &[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = array_utils::validate_and_zip( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ); + + if result.is_err() { + // We want to trigger a panic, because the passed parameters failed validation in some way. + // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + chunk_size, + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } +} + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[allow(dead_code)] +#[inline(always)] +pub fn fft_helper_immut_unroll2x( + input: &[T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&[T], &mut [T]), + chunk_fn: impl FnMut(&[T], &mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = + array_utils::validate_and_zip_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn); + + if result.is_err() { + // We want to trigger a panic, because the passed 
parameters failed validation in some way. + // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(chunk_size, input.len(), output.len(), 0, 0); + } +} + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[inline(always)] +pub fn fft_helper_outofplace( + input: &mut [T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = array_utils::validate_and_zip_mut( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ); + + if result.is_err() { + // We want to trigger a panic, because the passed parameters failed validation in some way. 
+ // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + chunk_size, + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } +} + +// A utility that validates the provided FFT parameters, executes the FFT if validation succeeds, or panics with a hopefully helpful error message if validation fails +// Since this implementation is basically always the same across all algorithms, this helper keeps us from having to duplicate it +#[allow(dead_code)] +#[inline(always)] +pub fn fft_helper_outofplace_unroll2x( + input: &mut [T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T], &mut [T]), + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + if chunk_size == 0 { + return; + } + + let result = + array_utils::validate_and_zip_mut_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn); + + if result.is_err() { + // We want to trigger a panic, because the passed parameters failed validation in some way. 
+ // But we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(chunk_size, input.len(), output.len(), 0, 0); + } +} diff --git a/src/lib.rs b/src/lib.rs index d88f330..c2154db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,7 @@ mod common; pub mod algorithm; mod array_utils; mod fft_cache; +mod fft_helper; mod math_utils; mod plan; mod twiddles; diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index 571e8da..b1f9108 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -3,10 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -25,107 +22,88 @@ unsafe fn pack_64(a: Complex) -> float64x2_t { #[allow(unused)] macro_rules! 
boilerplate_fft_neon_f32_butterfly { - ($struct_name:ident) => { - impl $struct_name { - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); - } - - pub(crate) unsafe fn perform_parallel_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_parallel_fft_contiguous(workaround_transmute_mut::<_, Complex>( - buffer, - )); - } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - pub(crate) unsafe fn perform_fft_butterfly_multi( + ($struct_name:ident, $len:expr, $direction_fn:expr) => { + impl Fft for $struct_name { + fn process_immutable_with_scratch( &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - let len = buffer.len(); - let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { - self.perform_parallel_fft_butterfly(chunk) - }); - if alldone.is_err() && buffer.len() >= self.len() { - self.perform_fft_butterfly(&mut buffer[len - self.len()..]); + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_immut_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } - Ok(()) } - - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( + fn process_outofplace_with_scratch( &self, - input: &[Complex], + input: &mut [Complex], output: &mut [Complex], - ) -> Result<(), ()> { - let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( - input, - output, - 2 * self.len(), - |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = 
workaround_transmute_mut(out_chunk); - self.perform_parallel_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }, - ); - if alldone.is_err() && input.len() >= self.len() { - let input_slice = crate::array_utils::workaround_transmute(input); - let output_slice = workaround_transmute_mut(output); - self.perform_fft_contiguous(DoubleBuf { - input: &input_slice[len - self.len()..], - output: &mut output_slice[len - self.len()..], - }) + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_outofplace_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } - Ok(()) } - } - }; -} - -macro_rules! boilerplate_fft_neon_f64_butterfly { - ($struct_name:ident) => { - impl $struct_name { - // Do a single fft - //#[target_feature(enable = "neon")] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); + fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::neon_common::neon_fft_helper_inplace_unroll2x( + simd_buffer, + self.len(), + |chunk| self.perform_parallel_fft_contiguous(chunk), + |chunk| self.perform_fft_contiguous(chunk), + ) + } } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - //#[target_feature(enable = "neon")] - pub(crate) unsafe fn perform_fft_butterfly_multi( - &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_butterfly(chunk) - }) + 
#[inline(always)] + fn get_inplace_scratch_len(&self) -> usize { + 0 } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( - &self, - input: &[Complex], - output: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }) + #[inline(always)] + fn get_outofplace_scratch_len(&self) -> usize { + 0 + } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } + } + impl Length for $struct_name { + #[inline(always)] + fn len(&self) -> usize { + $len + } + } + impl Direction for $struct_name { + #[inline(always)] + fn fft_direction(&self) -> FftDirection { + $direction_fn(self) } } }; } -#[allow(unused)] -macro_rules! boilerplate_fft_neon_common_butterfly { +macro_rules! boilerplate_fft_neon_f64_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { fn process_immutable_with_scratch( @@ -134,17 +112,17 @@ macro_rules! 
boilerplate_fft_neon_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_outofplace_with_scratch( @@ -153,32 +131,29 @@ macro_rules! 
boilerplate_fft_neon_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = unsafe { self.perform_fft_butterfly_multi(buffer) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a 
function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::neon_common::neon_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_contiguous(chunk), + ) } } #[inline(always)] @@ -221,8 +196,7 @@ pub struct NeonF32Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly1); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly1, 1, |this: &NeonF32Butterfly1<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly1, 1, |this: &NeonF32Butterfly1<_>| this .direction); impl NeonF32Butterfly1 { #[inline(always)] @@ -261,8 +235,7 @@ pub struct NeonF64Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly1); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly1, 1, |this: &NeonF64Butterfly1<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly1, 1, |this: &NeonF64Butterfly1<_>| this .direction); impl NeonF64Butterfly1 { #[inline(always)] @@ -292,8 +265,7 @@ pub struct NeonF32Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly2); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly2, 2, |this: &NeonF32Butterfly2<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly2, 2, |this: &NeonF32Butterfly2<_>| this .direction); impl NeonF32Butterfly2 { #[inline(always)] @@ -389,8 +361,7 @@ pub struct NeonF64Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly2); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly2, 2, |this: &NeonF64Butterfly2<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly2, 2, |this: &NeonF64Butterfly2<_>| this .direction); impl NeonF64Butterfly2 { #[inline(always)] @@ -447,8 +418,7 @@ pub struct NeonF32Butterfly3 { 
twiddle2im: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly3); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly3, 3, |this: &NeonF32Butterfly3<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly3, 3, |this: &NeonF32Butterfly3<_>| this .direction); impl NeonF32Butterfly3 { #[inline(always)] @@ -563,8 +533,7 @@ pub struct NeonF64Butterfly3 { twiddle2im: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly3); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly3, 3, |this: &NeonF64Butterfly3<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly3, 3, |this: &NeonF64Butterfly3<_>| this .direction); impl NeonF64Butterfly3 { #[inline(always)] @@ -635,8 +604,7 @@ pub struct NeonF32Butterfly4 { rotate: Rotate90F32, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly4); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly4, 4, |this: &NeonF32Butterfly4<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly4, 4, |this: &NeonF32Butterfly4<_>| this .direction); impl NeonF32Butterfly4 { #[inline(always)] @@ -756,8 +724,7 @@ pub struct NeonF64Butterfly4 { rotate: Rotate90F64, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly4); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly4, 4, |this: &NeonF64Butterfly4<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly4, 4, |this: &NeonF64Butterfly4<_>| this .direction); impl NeonF64Butterfly4 { #[inline(always)] @@ -837,8 +804,7 @@ pub struct NeonF32Butterfly5 { twiddle2im: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly5); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly5, 5, |this: &NeonF32Butterfly5<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly5, 5, |this: &NeonF32Butterfly5<_>| this .direction); impl NeonF32Butterfly5 { #[inline(always)] @@ -1003,8 +969,7 @@ pub struct NeonF64Butterfly5 { twiddle2im: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly5); 
-boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly5, 5, |this: &NeonF64Butterfly5<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly5, 5, |this: &NeonF64Butterfly5<_>| this .direction); impl NeonF64Butterfly5 { #[inline(always)] @@ -1099,13 +1064,12 @@ impl NeonF64Butterfly5 { // pub struct NeonF32Butterfly6 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF32Butterfly3, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly6); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly6, 6, |this: &NeonF32Butterfly6<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly6, 6, |this: &NeonF32Butterfly6<_>| this + .bf3 .direction); impl NeonF32Butterfly6 { #[inline(always)] @@ -1114,7 +1078,6 @@ impl NeonF32Butterfly6 { let bf3 = NeonF32Butterfly3::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, } @@ -1216,13 +1179,12 @@ impl NeonF32Butterfly6 { // pub struct NeonF64Butterfly6 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF64Butterfly3, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly6); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly6, 6, |this: &NeonF64Butterfly6<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly6, 6, |this: &NeonF64Butterfly6<_>| this + .bf3 .direction); impl NeonF64Butterfly6 { #[inline(always)] @@ -1231,7 +1193,6 @@ impl NeonF64Butterfly6 { let bf3 = NeonF64Butterfly3::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, } @@ -1286,13 +1247,12 @@ impl NeonF64Butterfly6 { pub struct NeonF32Butterfly8 { root2: float32x4_t, root2_dual: float32x4_t, - direction: FftDirection, bf4: NeonF32Butterfly4, rotate90: Rotate90F32, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly8); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly8, 8, |this: &NeonF32Butterfly8<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly8, 8, |this: &NeonF32Butterfly8<_>| this + .bf4 .direction); impl 
NeonF32Butterfly8 { #[inline(always)] @@ -1310,7 +1270,6 @@ impl NeonF32Butterfly8 { Self { root2, root2_dual, - direction, bf4, rotate90, } @@ -1419,13 +1378,12 @@ impl NeonF32Butterfly8 { pub struct NeonF64Butterfly8 { root2: float64x2_t, - direction: FftDirection, bf4: NeonF64Butterfly4, rotate90: Rotate90F64, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly8); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly8, 8, |this: &NeonF64Butterfly8<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly8, 8, |this: &NeonF64Butterfly8<_>| this + .bf4 .direction); impl NeonF64Butterfly8 { #[inline(always)] @@ -1440,7 +1398,6 @@ impl NeonF64Butterfly8 { }; Self { root2, - direction, bf4, rotate90, } @@ -1499,7 +1456,6 @@ impl NeonF64Butterfly8 { // /_/ |____/_____|_.__/|_|\__| // pub struct NeonF32Butterfly9 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF32Butterfly3, twiddle1: float32x4_t, @@ -1507,8 +1463,8 @@ pub struct NeonF32Butterfly9 { twiddle4: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly9); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly9, 9, |this: &NeonF32Butterfly9<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly9, 9, |this: &NeonF32Butterfly9<_>| this + .bf3 .direction); impl NeonF32Butterfly9 { #[inline(always)] @@ -1523,7 +1479,6 @@ impl NeonF32Butterfly9 { let twiddle4 = unsafe { vld1q_f32([tw4.re, tw4.im, tw4.re, tw4.im].as_ptr()) }; Self { - direction, _phantom: std::marker::PhantomData, bf3, twiddle1, @@ -1628,7 +1583,6 @@ impl NeonF32Butterfly9 { // pub struct NeonF64Butterfly9 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF64Butterfly3, twiddle1: float64x2_t, @@ -1636,8 +1590,8 @@ pub struct NeonF64Butterfly9 { twiddle4: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly9); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly9, 9, |this: &NeonF64Butterfly9<_>| this 
+boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly9, 9, |this: &NeonF64Butterfly9<_>| this + .bf3 .direction); impl NeonF64Butterfly9 { #[inline(always)] @@ -1652,7 +1606,6 @@ impl NeonF64Butterfly9 { let twiddle4 = unsafe { vld1q_f64([tw4.re, tw4.im].as_ptr()) }; Self { - direction, _phantom: std::marker::PhantomData, bf3, twiddle1, @@ -1703,13 +1656,12 @@ impl NeonF64Butterfly9 { // pub struct NeonF32Butterfly10 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf5: NeonF32Butterfly5, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly10); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly10, 10, |this: &NeonF32Butterfly10<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly10, 10, |this: &NeonF32Butterfly10<_>| this + .bf5 .direction); impl NeonF32Butterfly10 { #[inline(always)] @@ -1717,7 +1669,6 @@ impl NeonF32Butterfly10 { assert_f32::(); let bf5 = NeonF32Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf5, } @@ -1820,14 +1771,13 @@ impl NeonF32Butterfly10 { // pub struct NeonF64Butterfly10 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf2: NeonF64Butterfly2, bf5: NeonF64Butterfly5, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly10); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly10, 10, |this: &NeonF64Butterfly10<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly10, 10, |this: &NeonF64Butterfly10<_>| this + .bf5 .direction); impl NeonF64Butterfly10 { #[inline(always)] @@ -1836,7 +1786,6 @@ impl NeonF64Butterfly10 { let bf2 = NeonF64Butterfly2::new(direction); let bf5 = NeonF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf2, bf5, @@ -1889,14 +1838,13 @@ impl NeonF64Butterfly10 { // pub struct NeonF32Butterfly12 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF32Butterfly3, bf4: NeonF32Butterfly4, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly12); 
-boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly12, 12, |this: &NeonF32Butterfly12<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly12, 12, |this: &NeonF32Butterfly12<_>| this + .bf4 .direction); impl NeonF32Butterfly12 { #[inline(always)] @@ -1905,7 +1853,6 @@ impl NeonF32Butterfly12 { let bf3 = NeonF32Butterfly3::new(direction); let bf4 = NeonF32Butterfly4::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf4, @@ -2026,14 +1973,13 @@ impl NeonF32Butterfly12 { // pub struct NeonF64Butterfly12 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF64Butterfly3, bf4: NeonF64Butterfly4, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly12); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly12, 12, |this: &NeonF64Butterfly12<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly12, 12, |this: &NeonF64Butterfly12<_>| this + .bf4 .direction); impl NeonF64Butterfly12 { #[inline(always)] @@ -2042,7 +1988,6 @@ impl NeonF64Butterfly12 { let bf3 = NeonF64Butterfly3::new(direction); let bf4 = NeonF64Butterfly4::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf4, @@ -2095,14 +2040,13 @@ impl NeonF64Butterfly12 { // |_|____/ |____/_____|_.__/|_|\__| // pub struct NeonF32Butterfly15 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF32Butterfly3, bf5: NeonF32Butterfly5, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly15); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly15, 15, |this: &NeonF32Butterfly15<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly15, 15, |this: &NeonF32Butterfly15<_>| this + .bf3 .direction); impl NeonF32Butterfly15 { #[inline(always)] @@ -2111,7 +2055,6 @@ impl NeonF32Butterfly15 { let bf3 = NeonF32Butterfly3::new(direction); let bf5 = NeonF32Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf5, @@ -2231,14 +2174,13 @@ impl NeonF32Butterfly15 { // 
pub struct NeonF64Butterfly15 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: NeonF64Butterfly3, bf5: NeonF64Butterfly5, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly15); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly15, 15, |this: &NeonF64Butterfly15<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly15, 15, |this: &NeonF64Butterfly15<_>| this + .bf3 .direction); impl NeonF64Butterfly15 { #[inline(always)] @@ -2247,7 +2189,6 @@ impl NeonF64Butterfly15 { let bf3 = NeonF64Butterfly3::new(direction); let bf5 = NeonF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf5, @@ -2310,8 +2251,7 @@ pub struct NeonF32Butterfly16 { twiddle9: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this .bf4 .direction); impl NeonF32Butterfly16 { @@ -2477,8 +2417,7 @@ pub struct NeonF64Butterfly16 { twiddle9: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly16); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly16, 16, |this: &NeonF64Butterfly16<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly16, 16, |this: &NeonF64Butterfly16<_>| this .bf4 .direction); impl NeonF64Butterfly16 { @@ -2586,8 +2525,7 @@ pub struct NeonF32Butterfly24 { twiddle10: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly24); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly24, 24, |this: &NeonF32Butterfly24<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly24, 24, |this: &NeonF32Butterfly24<_>| this .bf4 .direction); impl NeonF32Butterfly24 { @@ -2791,8 +2729,7 @@ pub struct NeonF64Butterfly24 { twiddle10: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly24); 
-boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly24, 24, |this: &NeonF64Butterfly24<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly24, 24, |this: &NeonF64Butterfly24<_>| this .bf4 .direction); impl NeonF64Butterfly24 { @@ -2924,8 +2861,7 @@ pub struct NeonF32Butterfly32 { twiddle21: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly32); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly32, 32, |this: &NeonF32Butterfly32<_>| this +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly32, 32, |this: &NeonF32Butterfly32<_>| this .bf8 .bf4 .direction); @@ -3166,8 +3102,7 @@ pub struct NeonF64Butterfly32 { twiddle21: float64x2_t, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly32); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly32, 32, |this: &NeonF64Butterfly32<_>| this +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly32, 32, |this: &NeonF64Butterfly32<_>| this .bf8 .bf4 .direction); diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index 635cf30..df31e9f 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -1,5 +1,10 @@ use std::any::TypeId; +use crate::fft_helper::{ + fft_helper_immut, fft_helper_immut_unroll2x, fft_helper_inplace, fft_helper_inplace_unroll2x, + fft_helper_outofplace, fft_helper_outofplace_unroll2x, +}; + // Helper function to assert we have the right float type pub fn assert_f32() { let id_f32 = TypeId::of::(); @@ -63,29 +68,17 @@ macro_rules! 
boilerplate_fft_neon_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_immut( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_immut(input, output, &mut []), + ); } } fn process_outofplace_with_scratch( @@ -94,66 +87,33 @@ macro_rules! 
boilerplate_fft_neon_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped_mut( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::neon_common::neon_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) - }, - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_out_of_place(input, output, &mut []), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + let simd_scratch = 
crate::array_utils::workaround_transmute_mut(scratch); + super::neon_common::neon_fft_helper_inplace( + simd_buffer, + simd_scratch, self.len(), - buffer.len(), self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = unsafe { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_out_of_place(chunk, scratch, &mut []); - chunk.copy_from_slice(scratch); - }) - }; - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); + |chunk, scratch| { + self.perform_fft_out_of_place(chunk, scratch, &mut []); + chunk.copy_from_slice(scratch); + }, + ) } } #[inline(always)] @@ -184,112 +144,95 @@ macro_rules! boilerplate_fft_neon_oop { }; } -/* Not used now, but maybe later for the mixed radixes etc -macro_rules! 
boilerplate_sse_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { - impl Fft for $struct_name { - fn process_outofplace_with_scratch( - &self, - input: &mut [Complex], - output: &mut [Complex], - scratch: &mut [Complex], - ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_outofplace_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the NEON target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_immut( + input: &[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), +) { + fft_helper_immut( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped_mut( - input, - output, - self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) - }, - ); +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the NEON target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_outofplace( + input: &mut 
[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), +) { + fft_helper_outofplace( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - } - } - fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the NEON target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_inplace( + buffer: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_inplace(buffer, scratch, chunk_size, required_scratch, chunk_fn) +} - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the NEON target feature, +// so that things like loading twiddle factor registers etc 
can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_immut_unroll2x( + input: &[T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&[T], &mut [T]), + chunk_fn: impl FnMut(&[T], &mut [T]), +) { + fft_helper_immut_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, scratch) - }); +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the NEON target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_outofplace_unroll2x( + input: &mut [T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T], &mut [T]), + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_outofplace_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - } - } - #[inline(always)] - fn get_inplace_scratch_len(&self) -> usize { - $inplace_scratch_len_fn(self) - } - #[inline(always)] - fn get_outofplace_scratch_len(&self) -> usize { - $out_of_place_scratch_len_fn(self) - } - } - impl Length for $struct_name { - #[inline(always)] - fn len(&self) -> usize { - $len_fn(self) - } - } - impl Direction for $struct_name { - #[inline(always)] - fn fft_direction(&self) -> FftDirection { - self.direction - } - } - }; +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the 
NEON target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "neon")] +pub unsafe fn neon_fft_helper_inplace_unroll2x( + buffer: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T]), + chunk_fn: impl FnMut(&mut [T]), +) { + fft_helper_inplace_unroll2x(buffer, chunk_size, chunk2x_fn, chunk_fn) } -*/ diff --git a/src/neon/neon_prime_butterflies.rs b/src/neon/neon_prime_butterflies.rs index 8ef4810..4916547 100644 --- a/src/neon/neon_prime_butterflies.rs +++ b/src/neon/neon_prime_butterflies.rs @@ -5,10 +5,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -88,8 +85,7 @@ struct NeonF32Butterfly7 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly7); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly7, 7, |this: &NeonF32Butterfly7<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly7, 7, |this: &NeonF32Butterfly7<_>| this.direction); impl NeonF32Butterfly7 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -193,8 +189,7 @@ struct NeonF64Butterfly7 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly7); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly7, 7, |this: &NeonF64Butterfly7<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly7, 7, |this: &NeonF64Butterfly7<_>| this.direction); impl NeonF64Butterfly7 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -269,8 +264,7 @@ struct NeonF32Butterfly11 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly11); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly11, 11, |this: &NeonF32Butterfly11<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly11, 11, |this: &NeonF32Butterfly11<_>| this.direction); impl NeonF32Butterfly11 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -424,8 +418,7 @@ struct NeonF64Butterfly11 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly11); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly11, 11, |this: &NeonF64Butterfly11<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly11, 11, |this: &NeonF64Butterfly11<_>| this.direction); impl NeonF64Butterfly11 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -542,8 +535,7 @@ struct NeonF32Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly13); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly13, 13, |this: &NeonF32Butterfly13<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly13, 13, |this: &NeonF32Butterfly13<_>| this.direction); impl NeonF32Butterfly13 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -728,8 +720,7 @@ struct NeonF64Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly13); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly13, 13, |this: &NeonF64Butterfly13<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly13, 13, |this: &NeonF64Butterfly13<_>| this.direction); impl NeonF64Butterfly13 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -873,8 +864,7 @@ struct NeonF32Butterfly17 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly17); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly17, 17, |this: &NeonF32Butterfly17<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly17, 17, |this: &NeonF32Butterfly17<_>| this.direction); impl NeonF32Butterfly17 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -1133,8 +1123,7 @@ struct NeonF64Butterfly17 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly17); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly17, 17, |this: &NeonF64Butterfly17<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly17, 17, |this: &NeonF64Butterfly17<_>| this.direction); impl NeonF64Butterfly17 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -1344,8 +1333,7 @@ struct NeonF32Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly19); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly19, 19, |this: &NeonF32Butterfly19<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly19, 19, |this: &NeonF32Butterfly19<_>| this.direction); impl NeonF32Butterfly19 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -1647,8 +1635,7 @@ struct NeonF64Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly19); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly19, 19, |this: &NeonF64Butterfly19<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly19, 19, |this: &NeonF64Butterfly19<_>| this.direction); impl NeonF64Butterfly19 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -1897,8 +1884,7 @@ struct NeonF32Butterfly23 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly23); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly23, 23, |this: &NeonF32Butterfly23<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly23, 23, |this: &NeonF32Butterfly23<_>| this.direction); impl NeonF32Butterfly23 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -2298,8 +2284,7 @@ struct NeonF64Butterfly23 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly23); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly23, 23, |this: &NeonF64Butterfly23<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly23, 23, |this: &NeonF64Butterfly23<_>| this.direction); impl NeonF64Butterfly23 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -2638,8 +2623,7 @@ struct NeonF32Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly29); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly29, 29, |this: &NeonF32Butterfly29<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly29, 29, |this: &NeonF32Butterfly29<_>| this.direction); impl NeonF32Butterfly29 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -3216,8 +3200,7 @@ struct NeonF64Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly29); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly29, 29, |this: &NeonF64Butterfly29<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly29, 29, |this: &NeonF64Butterfly29<_>| this.direction); impl NeonF64Butterfly29 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -3721,8 +3704,7 @@ struct NeonF32Butterfly31 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly31); -boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly31, 31, |this: &NeonF32Butterfly31<_>| this.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly31, 31, |this: &NeonF32Butterfly31<_>| this.direction); impl NeonF32Butterfly31 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] @@ -4366,8 +4348,7 @@ struct NeonF64Butterfly31 { _phantom: std::marker::PhantomData, } -boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly31); -boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly31, 31, |this: &NeonF64Butterfly31<_>| this.direction); +boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly31, 31, |this: &NeonF64Butterfly31<_>| this.direction); impl NeonF64Butterfly31 { /// Safety: The current machine must support the neon instruction set #[target_feature(enable = "neon")] diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 2fbfe4e..fe7aee2 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -3,8 +3,7 @@ use num_complex::Complex; use std::any::TypeId; use std::sync::Arc; -use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{bitreversed_transpose, workaround_transmute_mut}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/neon/neon_vector.rs b/src/neon/neon_vector.rs index 3ab9bf9..035cf4d 100644 --- a/src/neon/neon_vector.rs +++ b/src/neon/neon_vector.rs @@ -132,7 +132,6 @@ pub struct Rotation90(V); // A trait to hold the BVectorType and COMPLEX_PER_VECTOR associated data pub trait NeonVector: Copy + Debug + Send + Sync { - const SCALAR_PER_VECTOR: usize; const COMPLEX_PER_VECTOR: usize; type ScalarType: NeonNum; @@ -145,6 +144,9 @@ pub trait NeonVector: Copy + Debug + Send + Sync { // stores of complex numbers unsafe fn 
store_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); + + // Keep this around even though it's unused - research went into how to do it, keeping it ensures that research doesn't need to be repeated + #[allow(unused)] unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); // math ops @@ -180,7 +182,6 @@ pub trait NeonVector: Copy + Debug + Send + Sync { } impl NeonVector for float32x4_t { - const SCALAR_PER_VECTOR: usize = 4; const COMPLEX_PER_VECTOR: usize = 2; type ScalarType = f32; @@ -315,7 +316,6 @@ impl NeonVector for float32x4_t { } impl NeonVector for float64x2_t { - const SCALAR_PER_VECTOR: usize = 2; const COMPLEX_PER_VECTOR: usize = 1; type ScalarType = f64; @@ -518,8 +518,6 @@ pub trait NeonArrayMut: NeonArray + DerefMut { unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize); // Store the low complex number from a NEON vector to the array. unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize); - // Store the high complex number from a NEON vector to the array. 
- unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize); } impl NeonArrayMut for &mut [Complex] { @@ -528,12 +526,6 @@ impl NeonArrayMut for &mut [Complex] { debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR); S::VectorType::store_complex(self.as_mut_ptr().add(index), vector) } - - #[inline(always)] - unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize) { - debug_assert!(self.len() >= index + 1); - S::VectorType::store_partial_hi_complex(self.as_mut_ptr().add(index), vector) - } #[inline(always)] unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize) { debug_assert!(self.len() >= index + 1); @@ -554,10 +546,6 @@ where unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize) { self.output.store_partial_lo_complex(vector, index); } - #[inline(always)] - unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize) { - self.output.store_partial_hi_complex(vector, index); - } } #[cfg(test)] diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 7ce8d21..c279662 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -3,10 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; use crate::array_utils::DoubleBuf; -use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -25,155 +22,169 @@ unsafe fn pack_64(a: Complex) -> __m128d { #[allow(unused)] macro_rules! 
boilerplate_fft_sse_f32_butterfly { - ($struct_name:ident) => { - impl $struct_name { - #[target_feature(enable = "sse4.1")] - //#[inline(always)] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); - } - - #[target_feature(enable = "sse4.1")] - //#[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_parallel_fft_contiguous(workaround_transmute_mut::<_, Complex>( - buffer, - )); - } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_fft_butterfly_multi( + ($struct_name:ident, $len:expr, $direction_fn:expr) => { + impl Fft for $struct_name { + fn process_immutable_with_scratch( &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - let len = buffer.len(); - let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { - self.perform_parallel_fft_butterfly(chunk) - }); - if alldone.is_err() && buffer.len() >= self.len() { - self.perform_fft_butterfly(&mut buffer[len - self.len()..]); + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_immut_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } - Ok(()) } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( + fn process_outofplace_with_scratch( &self, - input: &[Complex], 
+ input: &mut [Complex], output: &mut [Complex], - ) -> Result<(), ()> { - let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( - input, - output, - 2 * self.len(), - |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_parallel_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }, - ); - if alldone.is_err() && input.len() >= self.len() { - let input_slice = crate::array_utils::workaround_transmute(input); - let output_slice = workaround_transmute_mut(output); - self.perform_fft_contiguous(DoubleBuf { - input: &input_slice[len - self.len()..], - output: &mut output_slice[len - self.len()..], - }) + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_outofplace_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); + } + } + fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::sse_common::sse_fft_helper_inplace_unroll2x( + simd_buffer, + self.len(), + |chunk| self.perform_parallel_fft_contiguous(chunk), + |chunk| self.perform_fft_contiguous(chunk), + ) } - Ok(()) + } + #[inline(always)] + fn get_inplace_scratch_len(&self) -> usize { + 0 + } + #[inline(always)] + fn get_outofplace_scratch_len(&self) -> usize { + 0 + } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } + } + impl Length for $struct_name { + #[inline(always)] + fn len(&self) -> usize { + $len + } + } + impl Direction for $struct_name { + #[inline(always)] 
+ fn fft_direction(&self) -> FftDirection { + $direction_fn(self) } } }; } macro_rules! boilerplate_fft_sse_f32_butterfly_noparallel { - ($struct_name:ident) => { - impl $struct_name { - // Do a single fft - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); - } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_fft_butterfly_multi( + ($struct_name:ident, $len:expr, $direction_fn:expr) => { + impl Fft for $struct_name { + fn process_immutable_with_scratch( &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_butterfly(chunk) - }) + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); + } } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( + fn process_outofplace_with_scratch( &self, - input: &[Complex], + input: &mut [Complex], output: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }) + _scratch: &mut [Complex], + ) { + unsafe { + 
let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); + } } - } - }; -} - -macro_rules! boilerplate_fft_sse_f64_butterfly { - ($struct_name:ident) => { - impl $struct_name { - // Do a single fft - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); + fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::sse_common::sse_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_contiguous(chunk), + ) + } } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_fft_butterfly_multi( - &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_butterfly(chunk) - }) + #[inline(always)] + fn get_inplace_scratch_len(&self) -> usize { + 0 } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - #[target_feature(enable = "sse4.1")] - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( - &self, - input: &[Complex], - output: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_contiguous(DoubleBuf { - input: input_slice, - output: 
output_slice, - }) - }) + #[inline(always)] + fn get_outofplace_scratch_len(&self) -> usize { + 0 + } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } + } + impl Length for $struct_name { + #[inline(always)] + fn len(&self) -> usize { + $len + } + } + impl Direction for $struct_name { + #[inline(always)] + fn fft_direction(&self) -> FftDirection { + $direction_fn(self) } } }; } -#[allow(unused)] -macro_rules! boilerplate_fft_sse_common_butterfly { +macro_rules! boilerplate_fft_sse_f64_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { fn process_immutable_with_scratch( @@ -182,17 +193,17 @@ macro_rules! boilerplate_fft_sse_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_outofplace_with_scratch( @@ -201,32 +212,29 @@ 
macro_rules! boilerplate_fft_sse_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = unsafe { self.perform_fft_butterfly_multi(buffer) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, 
so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::sse_common::sse_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_contiguous(chunk), + ) } } #[inline(always)] @@ -269,8 +277,7 @@ pub struct SseF32Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly1); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly1, 1, |this: &SseF32Butterfly1<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly1, 1, |this: &SseF32Butterfly1<_>| this .direction); impl SseF32Butterfly1 { #[inline(always)] @@ -306,8 +313,7 @@ pub struct SseF64Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly1); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly1, 1, |this: &SseF64Butterfly1<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly1, 1, |this: &SseF64Butterfly1<_>| this .direction); impl SseF64Butterfly1 { #[inline(always)] @@ -337,8 +343,7 @@ pub struct SseF32Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly2); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly2, 2, |this: &SseF32Butterfly2<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly2, 2, |this: &SseF32Butterfly2<_>| this .direction); impl SseF32Butterfly2 { #[inline(always)] @@ -428,8 +433,7 @@ pub struct SseF64Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly2); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly2, 2, |this: &SseF64Butterfly2<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly2, 2, |this: &SseF64Butterfly2<_>| this .direction); impl SseF64Butterfly2 { #[inline(always)] @@ -485,8 +489,7 @@ pub struct SseF32Butterfly3 { twiddle1im: __m128, } 
-boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly3); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly3, 3, |this: &SseF32Butterfly3<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly3, 3, |this: &SseF32Butterfly3<_>| this .direction); impl SseF32Butterfly3 { #[inline(always)] @@ -598,8 +601,7 @@ pub struct SseF64Butterfly3 { twiddle1im: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly3); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly3, 3, |this: &SseF64Butterfly3<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly3, 3, |this: &SseF64Butterfly3<_>| this .direction); impl SseF64Butterfly3 { #[inline(always)] @@ -671,8 +673,7 @@ pub struct SseF32Butterfly4 { rotate: Rotate90F32, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly4); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly4, 4, |this: &SseF32Butterfly4<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly4, 4, |this: &SseF32Butterfly4<_>| this .direction); impl SseF32Butterfly4 { #[inline(always)] @@ -786,8 +787,7 @@ pub struct SseF64Butterfly4 { rotate: Rotate90F64, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly4); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly4, 4, |this: &SseF64Butterfly4<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly4, 4, |this: &SseF64Butterfly4<_>| this .direction); impl SseF64Butterfly4 { #[inline(always)] @@ -867,8 +867,7 @@ pub struct SseF32Butterfly5 { twiddle2im: __m128, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly5); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly5, 5, |this: &SseF32Butterfly5<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly5, 5, |this: &SseF32Butterfly5<_>| this .direction); impl SseF32Butterfly5 { #[inline(always)] @@ -1034,8 +1033,7 @@ pub struct SseF64Butterfly5 { twiddle2im: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly5); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly5, 5, |this: &SseF64Butterfly5<_>| this 
+boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly5, 5, |this: &SseF64Butterfly5<_>| this .direction); impl SseF64Butterfly5 { #[inline(always)] @@ -1135,8 +1133,7 @@ pub struct SseF32Butterfly6 { bf3: SseF32Butterfly3, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly6); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly6, 6, |this: &SseF32Butterfly6<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly6, 6, |this: &SseF32Butterfly6<_>| this .direction); impl SseF32Butterfly6 { #[inline(always)] @@ -1244,13 +1241,12 @@ impl SseF32Butterfly6 { // pub struct SseF64Butterfly6 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: SseF64Butterfly3, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly6); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly6, 6, |this: &SseF64Butterfly6<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly6, 6, |this: &SseF64Butterfly6<_>| this + .bf3 .direction); impl SseF64Butterfly6 { #[inline(always)] @@ -1259,7 +1255,6 @@ impl SseF64Butterfly6 { let bf3 = SseF64Butterfly3::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, } @@ -1319,8 +1314,7 @@ pub struct SseF32Butterfly8 { rotate90: Rotate90F32, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly8); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly8, 8, |this: &SseF32Butterfly8<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly8, 8, |this: &SseF32Butterfly8<_>| this .direction); impl SseF32Butterfly8 { #[inline(always)] @@ -1445,8 +1439,7 @@ pub struct SseF64Butterfly8 { bf4: SseF64Butterfly4, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly8); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this .bf4 .direction); impl SseF64Butterfly8 { @@ -1515,8 +1508,7 @@ pub struct SseF32Butterfly9 { twiddle4: __m128, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly9); 
-boilerplate_fft_sse_common_butterfly!(SseF32Butterfly9, 9, |this: &SseF32Butterfly9<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly9, 9, |this: &SseF32Butterfly9<_>| this .direction); impl SseF32Butterfly9 { #[inline(always)] @@ -1629,7 +1621,6 @@ impl SseF32Butterfly9 { // pub struct SseF64Butterfly9 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: SseF64Butterfly3, twiddle1: __m128d, @@ -1637,8 +1628,8 @@ pub struct SseF64Butterfly9 { twiddle4: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly9); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly9, 9, |this: &SseF64Butterfly9<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly9, 9, |this: &SseF64Butterfly9<_>| this + .bf3 .direction); impl SseF64Butterfly9 { #[inline(always)] @@ -1652,7 +1643,6 @@ impl SseF64Butterfly9 { let twiddle2 = unsafe { _mm_set_pd(tw2.im, tw2.re) }; let twiddle4 = unsafe { _mm_set_pd(tw4.im, tw4.re) }; Self { - direction, _phantom: std::marker::PhantomData, bf3, twiddle1, @@ -1708,8 +1698,7 @@ pub struct SseF32Butterfly10 { bf5: SseF32Butterfly5, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly10); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly10, 10, |this: &SseF32Butterfly10<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly10, 10, |this: &SseF32Butterfly10<_>| this .direction); impl SseF32Butterfly10 { #[inline(always)] @@ -1814,14 +1803,13 @@ impl SseF32Butterfly10 { // pub struct SseF64Butterfly10 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf2: SseF64Butterfly2, bf5: SseF64Butterfly5, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly10); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly10, 10, |this: &SseF64Butterfly10<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly10, 10, |this: &SseF64Butterfly10<_>| this + .bf5 .direction); impl SseF64Butterfly10 { #[inline(always)] @@ -1830,7 +1818,6 @@ impl SseF64Butterfly10 { let bf2 = SseF64Butterfly2::new(direction); 
let bf5 = SseF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf2, bf5, @@ -1889,8 +1876,7 @@ pub struct SseF32Butterfly12 { bf4: SseF32Butterfly4, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly12); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly12, 12, |this: &SseF32Butterfly12<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly12, 12, |this: &SseF32Butterfly12<_>| this .direction); impl SseF32Butterfly12 { #[inline(always)] @@ -2014,14 +2000,13 @@ impl SseF32Butterfly12 { // pub struct SseF64Butterfly12 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: SseF64Butterfly3, bf4: SseF64Butterfly4, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly12); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly12, 12, |this: &SseF64Butterfly12<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly12, 12, |this: &SseF64Butterfly12<_>| this + .bf3 .direction); impl SseF64Butterfly12 { #[inline(always)] @@ -2030,7 +2015,6 @@ impl SseF64Butterfly12 { let bf3 = SseF64Butterfly3::new(direction); let bf4 = SseF64Butterfly4::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf4, @@ -2089,8 +2073,7 @@ pub struct SseF32Butterfly15 { bf5: SseF32Butterfly5, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly15); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly15, 15, |this: &SseF32Butterfly15<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly15, 15, |this: &SseF32Butterfly15<_>| this .direction); impl SseF32Butterfly15 { #[inline(always)] @@ -2213,14 +2196,13 @@ impl SseF32Butterfly15 { // pub struct SseF64Butterfly15 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: SseF64Butterfly3, bf5: SseF64Butterfly5, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly15); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly15, 15, |this: &SseF64Butterfly15<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly15, 15, |this: 
&SseF64Butterfly15<_>| this + .bf3 .direction); impl SseF64Butterfly15 { #[inline(always)] @@ -2229,7 +2211,6 @@ impl SseF64Butterfly15 { let bf3 = SseF64Butterfly3::new(direction); let bf5 = SseF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf5, @@ -2292,10 +2273,9 @@ pub struct SseF32Butterfly16 { twiddle9: __m128, } -boilerplate_fft_sse_f32_butterfly_noparallel!(SseF32Butterfly16); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this - .bf4 - .direction); +boilerplate_fft_sse_f32_butterfly_noparallel!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16< + _, +>| this.bf4.direction); impl SseF32Butterfly16 { pub fn new(direction: FftDirection) -> Self { assert_f32::(); @@ -2458,8 +2438,7 @@ pub struct SseF64Butterfly16 { twiddle9: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly16); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this .bf4 .direction); impl SseF64Butterfly16 { @@ -2567,8 +2546,7 @@ pub struct SseF32Butterfly24 { twiddle10: __m128, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly24); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfly24<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfly24<_>| this .bf4 .direction); impl SseF32Butterfly24 { @@ -2769,8 +2747,7 @@ pub struct SseF64Butterfly24 { twiddle10: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly24); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly24, 24, |this: &SseF64Butterfly24<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly24, 24, |this: &SseF64Butterfly24<_>| this .bf4 .direction); impl SseF64Butterfly24 { @@ -2902,8 +2879,7 @@ pub struct SseF32Butterfly32 { twiddle21: __m128, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly32); 
-boilerplate_fft_sse_common_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this .bf8 .bf4 .direction); @@ -3141,8 +3117,7 @@ pub struct SseF64Butterfly32 { twiddle21: __m128d, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly32); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this .bf8 .bf4 .direction); diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 86593ab..78b2673 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -1,5 +1,10 @@ use std::any::TypeId; +use crate::fft_helper::{ + fft_helper_immut, fft_helper_immut_unroll2x, fft_helper_inplace, fft_helper_inplace_unroll2x, + fft_helper_outofplace, fft_helper_outofplace_unroll2x, +}; + // Helper function to assert we have the right float type pub fn assert_f32() { let id_f32 = TypeId::of::(); @@ -63,29 +68,17 @@ macro_rules! 
boilerplate_fft_sse_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_immut( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_immut(input, output, &mut []), + ); } } fn process_outofplace_with_scratch( @@ -94,66 +87,33 @@ macro_rules! 
boilerplate_fft_sse_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::sse_common::sse_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) - }, - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_out_of_place(input, output, &mut []), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + let simd_scratch = crate::array_utils::workaround_transmute_mut(scratch); 
+ super::sse_common::sse_fft_helper_inplace( + simd_buffer, + simd_scratch, self.len(), - buffer.len(), self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = unsafe { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_out_of_place(chunk, scratch, &mut []); - chunk.copy_from_slice(scratch); - }) - }; - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); + |chunk, scratch| { + self.perform_fft_out_of_place(chunk, scratch, &mut []); + chunk.copy_from_slice(scratch); + }, + ) } } #[inline(always)] @@ -184,112 +144,95 @@ macro_rules! boilerplate_fft_sse_oop { }; } -/* Not used now, but maybe later for the mixed radixes etc -macro_rules! 
boilerplate_sse_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { - impl Fft for $struct_name { - fn process_outofplace_with_scratch( - &self, - input: &mut [Complex], - output: &mut [Complex], - scratch: &mut [Complex], - ) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_outofplace_scratch_len(); - if scratch.len() < required_scratch - || input.len() < self.len() - || output.len() != input.len() - { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_immut( + input: &[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), +) { + fft_helper_immut( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped_mut( - input, - output, - self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) - }, - ); +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_outofplace( + input: &mut 
[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), +) { + fft_helper_outofplace( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( - self.len(), - input.len(), - output.len(), - self.get_outofplace_scratch_len(), - scratch.len(), - ); - } - } - fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_inplace( + buffer: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_inplace(buffer, scratch, chunk_size, required_scratch, chunk_fn) +} - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE target feature, +// so that things like loading twiddle factor registers etc can 
be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_immut_unroll2x( + input: &[T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&[T], &mut [T]), + chunk_fn: impl FnMut(&[T], &mut [T]), +) { + fft_helper_immut_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} - let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_inplace(chunk, scratch) - }); +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_outofplace_unroll2x( + input: &mut [T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T], &mut [T]), + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_outofplace_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - } - } - #[inline(always)] - fn get_inplace_scratch_len(&self) -> usize { - $inplace_scratch_len_fn(self) - } - #[inline(always)] - fn get_outofplace_scratch_len(&self) -> usize { - $out_of_place_scratch_len_fn(self) - } - } - impl Length for $struct_name { - #[inline(always)] - fn len(&self) -> usize { - $len_fn(self) - } - } - impl Direction for $struct_name { - #[inline(always)] - fn fft_direction(&self) -> FftDirection { - self.direction - } - } - }; +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the SSE 
target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "sse4.1")] +pub unsafe fn sse_fft_helper_inplace_unroll2x( + buffer: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T]), + chunk_fn: impl FnMut(&mut [T]), +) { + fft_helper_inplace_unroll2x(buffer, chunk_size, chunk2x_fn, chunk_fn) } -*/ diff --git a/src/sse/sse_prime_butterflies.rs b/src/sse/sse_prime_butterflies.rs index ac9d500..0de599a 100644 --- a/src/sse/sse_prime_butterflies.rs +++ b/src/sse/sse_prime_butterflies.rs @@ -5,10 +5,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -88,8 +85,7 @@ struct SseF32Butterfly7 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly7); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly7, 7, |this: &SseF32Butterfly7<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly7, 7, |this: &SseF32Butterfly7<_>| this.direction); impl SseF32Butterfly7 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -193,8 +189,7 @@ struct SseF64Butterfly7 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly7); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly7, 7, |this: &SseF64Butterfly7<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly7, 7, |this: &SseF64Butterfly7<_>| this.direction); impl SseF64Butterfly7 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -269,8 +264,7 @@ struct SseF32Butterfly11 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly11); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly11, 11, |this: &SseF32Butterfly11<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly11, 11, |this: &SseF32Butterfly11<_>| this.direction); impl SseF32Butterfly11 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -424,8 +418,7 @@ struct SseF64Butterfly11 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly11); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly11, 11, |this: &SseF64Butterfly11<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly11, 11, |this: &SseF64Butterfly11<_>| this.direction); impl SseF64Butterfly11 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -542,8 +535,7 @@ struct SseF32Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly13); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly13, 13, |this: &SseF32Butterfly13<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly13, 13, |this: &SseF32Butterfly13<_>| this.direction); impl SseF32Butterfly13 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -728,8 +720,7 @@ struct SseF64Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly13); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly13, 13, |this: &SseF64Butterfly13<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly13, 13, |this: &SseF64Butterfly13<_>| this.direction); impl SseF64Butterfly13 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -873,8 +864,7 @@ struct SseF32Butterfly17 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly17); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly17, 17, |this: &SseF32Butterfly17<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly17, 17, |this: &SseF32Butterfly17<_>| this.direction); impl SseF32Butterfly17 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -1133,8 +1123,7 @@ struct SseF64Butterfly17 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly17); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly17, 17, |this: &SseF64Butterfly17<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly17, 17, |this: &SseF64Butterfly17<_>| this.direction); impl SseF64Butterfly17 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -1344,8 +1333,7 @@ struct SseF32Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly19); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly19, 19, |this: &SseF32Butterfly19<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly19, 19, |this: &SseF32Butterfly19<_>| this.direction); impl SseF32Butterfly19 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -1647,8 +1635,7 @@ struct SseF64Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly19); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly19, 19, |this: &SseF64Butterfly19<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly19, 19, |this: &SseF64Butterfly19<_>| this.direction); impl SseF64Butterfly19 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -1897,8 +1884,7 @@ struct SseF32Butterfly23 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly23); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly23, 23, |this: &SseF32Butterfly23<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly23, 23, |this: &SseF32Butterfly23<_>| this.direction); impl SseF32Butterfly23 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -2298,8 +2284,7 @@ struct SseF64Butterfly23 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly23); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly23, 23, |this: &SseF64Butterfly23<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly23, 23, |this: &SseF64Butterfly23<_>| this.direction); impl SseF64Butterfly23 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -2638,8 +2623,7 @@ struct SseF32Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly29); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly29, 29, |this: &SseF32Butterfly29<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly29, 29, |this: &SseF32Butterfly29<_>| this.direction); impl SseF32Butterfly29 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -3216,8 +3200,7 @@ struct SseF64Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly29); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly29, 29, |this: &SseF64Butterfly29<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly29, 29, |this: &SseF64Butterfly29<_>| this.direction); impl SseF64Butterfly29 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -3721,8 +3704,7 @@ struct SseF32Butterfly31 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly31); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly31, 31, |this: &SseF32Butterfly31<_>| this.direction); +boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly31, 31, |this: &SseF32Butterfly31<_>| this.direction); impl SseF32Butterfly31 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] @@ -4366,8 +4348,7 @@ struct SseF64Butterfly31 { _phantom: std::marker::PhantomData, } -boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly31); -boilerplate_fft_sse_common_butterfly!(SseF64Butterfly31, 31, |this: &SseF64Butterfly31<_>| this.direction); +boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly31, 31, |this: &SseF64Butterfly31<_>| this.direction); impl SseF64Butterfly31 { /// Safety: The current machine must support the sse4.1 instruction set #[target_feature(enable = "sse4.1")] diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index 43eb397..c788b2d 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -3,8 +3,7 @@ use num_complex::Complex; use std::any::TypeId; use std::sync::Arc; -use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{bitreversed_transpose, workaround_transmute_mut}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/sse/sse_vector.rs b/src/sse/sse_vector.rs index 6470a50..66daa4c 100644 --- a/src/sse/sse_vector.rs +++ b/src/sse/sse_vector.rs @@ -133,7 +133,6 @@ pub struct Rotation90(V); // A trait to hold the BVectorType and COMPLEX_PER_VECTOR associated data pub trait SseVector: Copy + Debug + Send + Sync { - const SCALAR_PER_VECTOR: usize; const COMPLEX_PER_VECTOR: usize; type ScalarType: SseNum; @@ -146,6 +145,9 @@ pub trait SseVector: Copy + Debug + Send + Sync { // stores of complex numbers unsafe fn store_complex(ptr: *mut Complex, data: 
Self); unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); + + // Keep this around even though it's unused - research went into how to do it, keeping it ensures that research doesn't need to be repeated + #[allow(unused)] unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); // math ops @@ -181,7 +183,6 @@ pub trait SseVector: Copy + Debug + Send + Sync { } impl SseVector for __m128 { - const SCALAR_PER_VECTOR: usize = 4; const COMPLEX_PER_VECTOR: usize = 2; type ScalarType = f32; @@ -308,7 +309,6 @@ impl SseVector for __m128 { } impl SseVector for __m128d { - const SCALAR_PER_VECTOR: usize = 2; const COMPLEX_PER_VECTOR: usize = 1; type ScalarType = f64; @@ -511,8 +511,6 @@ pub trait SseArrayMut: SseArray + DerefMut { unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize); // Store the low complex number from a SSE vector to the array. unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize); - // Store the high complex number from a SSE vector to the array. 
- unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize); } impl SseArrayMut for &mut [Complex] { @@ -521,12 +519,6 @@ impl SseArrayMut for &mut [Complex] { debug_assert!(self.len() >= index + S::VectorType::COMPLEX_PER_VECTOR); S::VectorType::store_complex(self.as_mut_ptr().add(index), vector) } - - #[inline(always)] - unsafe fn store_partial_hi_complex(&mut self, vector: S::VectorType, index: usize) { - debug_assert!(self.len() >= index + 1); - S::VectorType::store_partial_hi_complex(self.as_mut_ptr().add(index), vector) - } #[inline(always)] unsafe fn store_partial_lo_complex(&mut self, vector: S::VectorType, index: usize) { debug_assert!(self.len() >= index + 1); @@ -547,8 +539,4 @@ where unsafe fn store_partial_lo_complex(&mut self, vector: T::VectorType, index: usize) { self.output.store_partial_lo_complex(vector, index); } - #[inline(always)] - unsafe fn store_partial_hi_complex(&mut self, vector: T::VectorType, index: usize) { - self.output.store_partial_hi_complex(vector, index); - } } diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 231c434..8b5b326 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -3,10 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -25,113 +22,88 @@ unsafe fn pack_64(a: Complex) -> v128 { #[allow(unused)] macro_rules! 
boilerplate_fft_wasm_simd_f32_butterfly { - ($struct_name:ident) => { - impl $struct_name { - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); - } - - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_parallel_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_parallel_fft_contiguous(workaround_transmute_mut::<_, Complex>( - buffer, - )); - } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_fft_butterfly_multi( + ($struct_name:ident, $len:expr, $direction_fn:expr) => { + impl Fft for $struct_name { + fn process_immutable_with_scratch( &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - let len = buffer.len(); - let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { - self.perform_parallel_fft_butterfly(chunk) - }); - if alldone.is_err() && buffer.len() >= self.len() { - self.perform_fft_butterfly(&mut buffer[len - self.len()..]); + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_immut_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } - Ok(()) } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( + fn process_outofplace_with_scratch( &self, - input: &[Complex], + input: &mut 
[Complex], output: &mut [Complex], - ) -> Result<(), ()> { - let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( - input, - output, - 2 * self.len(), - |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_parallel_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }, - ); - if alldone.is_err() && input.len() >= self.len() { - let input_slice = crate::array_utils::workaround_transmute(input); - let output_slice = workaround_transmute_mut(output); - self.perform_fft_contiguous(DoubleBuf { - input: &input_slice[len - self.len()..], - output: &mut output_slice[len - self.len()..], - }) + _scratch: &mut [Complex], + ) { + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_outofplace_unroll2x( + simd_input, + simd_output, + self.len(), + |input, output| { + self.perform_parallel_fft_contiguous(DoubleBuf { input, output }) + }, + |input, output| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } - Ok(()) } - } - }; -} - -macro_rules! 
boilerplate_fft_wasm_simd_f64_butterfly { - ($struct_name:ident) => { - impl $struct_name { - // Do a single fft - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { - self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); + fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::wasm_simd_common::wasm_simd_fft_helper_inplace_unroll2x( + simd_buffer, + self.len(), + |chunk| self.perform_parallel_fft_contiguous(chunk), + |chunk| self.perform_fft_contiguous(chunk), + ) + } } - - // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_fft_butterfly_multi( - &self, - buffer: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_butterfly(chunk) - }) + #[inline(always)] + fn get_inplace_scratch_len(&self) -> usize { + 0 } - - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - #[target_feature(enable = "simd128")] - pub(crate) unsafe fn perform_oop_fft_butterfly_multi( - &self, - input: &[Complex], - output: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = crate::array_utils::workaround_transmute(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }) + #[inline(always)] + fn get_outofplace_scratch_len(&self) -> usize { + 0 + } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } + } + impl Length for $struct_name { + #[inline(always)] + fn len(&self) -> usize { + $len + } + } + impl Direction for 
$struct_name { + #[inline(always)] + fn fft_direction(&self) -> FftDirection { + $direction_fn(self) } } }; } -#[allow(unused)] -macro_rules! boilerplate_fft_wasm_simd_common_butterfly { +macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { fn process_immutable_with_scratch( @@ -140,17 +112,17 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_immut( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_outofplace_with_scratch( @@ -159,32 +131,29 @@ macro_rules! 
boilerplate_fft_wasm_simd_common_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], + self.len(), + 0, + |input, output, _| self.perform_fft_contiguous(DoubleBuf { input, output }), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { - if buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); - return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here - } - - let result = unsafe { self.perform_fft_butterfly_multi(buffer) }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code 
size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace(self.len(), buffer.len(), 0, 0); + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + super::wasm_simd_common::wasm_simd_fft_helper_inplace( + simd_buffer, + &mut [], + self.len(), + 0, + |chunk, _| self.perform_fft_contiguous(chunk), + ) } } #[inline(always)] @@ -214,7 +183,6 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { } }; } - // _ _________ _ _ _ // / | |___ /___ \| |__ (_) |_ // | | _____ |_ \ __) | '_ \| | __| @@ -227,8 +195,7 @@ pub struct WasmSimdF32Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly1); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly1, 1, |this: &WasmSimdF32Butterfly1<_>| this.direction @@ -270,8 +237,7 @@ pub struct WasmSimdF64Butterfly1 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly1); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly1, 1, |this: &WasmSimdF64Butterfly1<_>| this.direction @@ -304,8 +270,7 @@ pub struct WasmSimdF32Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly2); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly2, 2, |this: &WasmSimdF32Butterfly2<_>| this.direction @@ -402,8 +367,7 @@ pub struct WasmSimdF64Butterfly2 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly2); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly2, 2, |this: &WasmSimdF64Butterfly2<_>| this.direction @@ -458,8 +422,7 @@ pub struct WasmSimdF32Butterfly3 { twiddle1im: v128, } 
-boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly3); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly3, 3, |this: &WasmSimdF32Butterfly3<_>| this.direction @@ -574,8 +537,7 @@ pub struct WasmSimdF64Butterfly3 { twiddle1im: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly3); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly3, 3, |this: &WasmSimdF64Butterfly3<_>| this.direction @@ -650,8 +612,7 @@ pub struct WasmSimdF32Butterfly4 { rotate: Rotate90F32, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly4); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly4, 4, |this: &WasmSimdF32Butterfly4<_>| this.direction @@ -767,8 +728,7 @@ pub struct WasmSimdF64Butterfly4 { rotate: Rotate90F64, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly4); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly4, 4, |this: &WasmSimdF64Butterfly4<_>| this.direction @@ -851,8 +811,7 @@ pub struct WasmSimdF32Butterfly5 { twiddle2im: v128, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly5); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly5, 5, |this: &WasmSimdF32Butterfly5<_>| this.direction @@ -1020,8 +979,7 @@ pub struct WasmSimdF64Butterfly5 { twiddle2im: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly5); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly5, 5, |this: &WasmSimdF64Butterfly5<_>| this.direction @@ -1119,16 +1077,14 @@ impl WasmSimdF64Butterfly5 { // pub struct WasmSimdF32Butterfly6 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF32Butterfly3, } 
-boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly6); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly6, 6, - |this: &WasmSimdF32Butterfly6<_>| this.direction + |this: &WasmSimdF32Butterfly6<_>| this.bf3.direction ); impl WasmSimdF32Butterfly6 { #[inline(always)] @@ -1137,7 +1093,6 @@ impl WasmSimdF32Butterfly6 { let bf3 = WasmSimdF32Butterfly3::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, } @@ -1239,16 +1194,14 @@ impl WasmSimdF32Butterfly6 { // pub struct WasmSimdF64Butterfly6 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF64Butterfly3, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly6); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly6, 6, - |this: &WasmSimdF64Butterfly6<_>| this.direction + |this: &WasmSimdF64Butterfly6<_>| this.bf3.direction ); impl WasmSimdF64Butterfly6 { #[inline(always)] @@ -1257,7 +1210,6 @@ impl WasmSimdF64Butterfly6 { let bf3 = WasmSimdF64Butterfly3::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, } @@ -1312,16 +1264,14 @@ impl WasmSimdF64Butterfly6 { pub struct WasmSimdF32Butterfly8 { root2: v128, root2_dual: v128, - direction: FftDirection, bf4: WasmSimdF32Butterfly4, rotate90: Rotate90F32, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly8); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly8, 8, - |this: &WasmSimdF32Butterfly8<_>| this.direction + |this: &WasmSimdF32Butterfly8<_>| this.bf4.direction ); impl WasmSimdF32Butterfly8 { #[inline(always)] @@ -1338,7 +1288,6 @@ impl WasmSimdF32Butterfly8 { Self { root2, root2_dual, - direction, bf4, rotate90, } @@ -1447,16 +1396,14 @@ impl WasmSimdF32Butterfly8 { pub struct WasmSimdF64Butterfly8 { root2: v128, - direction: FftDirection, bf4: WasmSimdF64Butterfly4, rotate90: 
Rotate90F64, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly8); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly8, 8, - |this: &WasmSimdF64Butterfly8<_>| this.direction + |this: &WasmSimdF64Butterfly8<_>| this.bf4.direction ); impl WasmSimdF64Butterfly8 { #[inline(always)] @@ -1471,7 +1418,6 @@ impl WasmSimdF64Butterfly8 { }; Self { root2, - direction, bf4, rotate90, } @@ -1530,7 +1476,6 @@ impl WasmSimdF64Butterfly8 { // /_/ |____/_____|_.__/|_|\__| // pub struct WasmSimdF32Butterfly9 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF32Butterfly3, twiddle1: v128, @@ -1538,11 +1483,10 @@ pub struct WasmSimdF32Butterfly9 { twiddle4: v128, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly9); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly9, 9, - |this: &WasmSimdF32Butterfly9<_>| this.direction + |this: &WasmSimdF32Butterfly9<_>| this.bf3.direction ); impl WasmSimdF32Butterfly9 { #[inline(always)] @@ -1557,7 +1501,6 @@ impl WasmSimdF32Butterfly9 { let twiddle4 = f32x4(tw4.re, tw4.im, tw4.re, tw4.im); Self { - direction, _phantom: std::marker::PhantomData, bf3, twiddle1, @@ -1659,7 +1602,6 @@ impl WasmSimdF32Butterfly9 { // pub struct WasmSimdF64Butterfly9 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF64Butterfly3, twiddle1: v128, @@ -1667,11 +1609,10 @@ pub struct WasmSimdF64Butterfly9 { twiddle4: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly9); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly9, 9, - |this: &WasmSimdF64Butterfly9<_>| this.direction + |this: &WasmSimdF64Butterfly9<_>| this.bf3.direction ); impl WasmSimdF64Butterfly9 { #[inline(always)] @@ -1686,7 +1627,6 @@ impl WasmSimdF64Butterfly9 { let twiddle4 = f64x2(tw4.re, tw4.im); Self { - direction, _phantom: 
std::marker::PhantomData, bf3, twiddle1, @@ -1737,16 +1677,14 @@ impl WasmSimdF64Butterfly9 { // pub struct WasmSimdF32Butterfly10 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf5: WasmSimdF32Butterfly5, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly10); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly10, 10, - |this: &WasmSimdF32Butterfly10<_>| this.direction + |this: &WasmSimdF32Butterfly10<_>| this.bf5.direction ); impl WasmSimdF32Butterfly10 { #[inline(always)] @@ -1754,7 +1692,6 @@ impl WasmSimdF32Butterfly10 { assert_f32::(); let bf5 = WasmSimdF32Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf5, } @@ -1854,17 +1791,15 @@ impl WasmSimdF32Butterfly10 { // pub struct WasmSimdF64Butterfly10 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf2: WasmSimdF64Butterfly2, bf5: WasmSimdF64Butterfly5, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly10); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly10, 10, - |this: &WasmSimdF64Butterfly10<_>| this.direction + |this: &WasmSimdF64Butterfly10<_>| this.bf5.direction ); impl WasmSimdF64Butterfly10 { #[inline(always)] @@ -1873,7 +1808,6 @@ impl WasmSimdF64Butterfly10 { let bf2 = WasmSimdF64Butterfly2::new(direction); let bf5 = WasmSimdF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf2, bf5, @@ -1926,17 +1860,15 @@ impl WasmSimdF64Butterfly10 { // pub struct WasmSimdF32Butterfly12 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF32Butterfly3, bf4: WasmSimdF32Butterfly4, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly12); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly12, 12, - |this: &WasmSimdF32Butterfly12<_>| this.direction + |this: 
&WasmSimdF32Butterfly12<_>| this.bf4.direction ); impl WasmSimdF32Butterfly12 { #[inline(always)] @@ -1945,7 +1877,6 @@ impl WasmSimdF32Butterfly12 { let bf3 = WasmSimdF32Butterfly3::new(direction); let bf4 = WasmSimdF32Butterfly4::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf4, @@ -2063,17 +1994,15 @@ impl WasmSimdF32Butterfly12 { // pub struct WasmSimdF64Butterfly12 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF64Butterfly3, bf4: WasmSimdF64Butterfly4, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly12); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly12, 12, - |this: &WasmSimdF64Butterfly12<_>| this.direction + |this: &WasmSimdF64Butterfly12<_>| this.bf4.direction ); impl WasmSimdF64Butterfly12 { #[inline(always)] @@ -2082,7 +2011,6 @@ impl WasmSimdF64Butterfly12 { let bf3 = WasmSimdF64Butterfly3::new(direction); let bf4 = WasmSimdF64Butterfly4::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf4, @@ -2135,17 +2063,15 @@ impl WasmSimdF64Butterfly12 { // |_|____/ |____/_____|_.__/|_|\__| // pub struct WasmSimdF32Butterfly15 { - direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF32Butterfly3, bf5: WasmSimdF32Butterfly5, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly15); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly15, 15, - |this: &WasmSimdF32Butterfly15<_>| this.direction + |this: &WasmSimdF32Butterfly15<_>| this.bf3.direction ); impl WasmSimdF32Butterfly15 { #[inline(always)] @@ -2154,7 +2080,6 @@ impl WasmSimdF32Butterfly15 { let bf3 = WasmSimdF32Butterfly3::new(direction); let bf5 = WasmSimdF32Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf5, @@ -2271,17 +2196,15 @@ impl WasmSimdF32Butterfly15 { // pub struct WasmSimdF64Butterfly15 { - 
direction: FftDirection, _phantom: std::marker::PhantomData, bf3: WasmSimdF64Butterfly3, bf5: WasmSimdF64Butterfly5, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly15); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly15, 15, - |this: &WasmSimdF64Butterfly15<_>| this.direction + |this: &WasmSimdF64Butterfly15<_>| this.bf3.direction ); impl WasmSimdF64Butterfly15 { #[inline(always)] @@ -2290,7 +2213,6 @@ impl WasmSimdF64Butterfly15 { let bf3 = WasmSimdF64Butterfly3::new(direction); let bf5 = WasmSimdF64Butterfly5::new(direction); Self { - direction, _phantom: std::marker::PhantomData, bf3, bf5, @@ -2353,8 +2275,7 @@ pub struct WasmSimdF32Butterfly16 { twiddle9: v128, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly16); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly16, 16, |this: &WasmSimdF32Butterfly16<_>| this.bf4.direction @@ -2538,8 +2459,7 @@ pub struct WasmSimdF64Butterfly16 { twiddle9: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly16); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly16, 16, |this: &WasmSimdF64Butterfly16<_>| this.bf4.direction @@ -2649,8 +2569,7 @@ pub struct WasmSimdF32Butterfly24 { twiddle10: v128, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly24); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly24, 24, |this: &WasmSimdF32Butterfly24<_>| this.bf4.direction @@ -2871,8 +2790,7 @@ pub struct WasmSimdF64Butterfly24 { twiddle10: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly24); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly24, 24, |this: &WasmSimdF64Butterfly24<_>| this.bf4.direction @@ -3006,8 +2924,7 @@ pub struct WasmSimdF32Butterfly32 { 
twiddle21: v128, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly32); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly32, 32, |this: &WasmSimdF32Butterfly32<_>| this.bf8.bf4.direction @@ -3270,8 +3187,7 @@ pub struct WasmSimdF64Butterfly32 { twiddle21: v128, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly32); -boilerplate_fft_wasm_simd_common_butterfly!( +boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly32, 32, |this: &WasmSimdF64Butterfly32<_>| this.bf8.bf4.direction diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 0b46f4c..609f20e 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -1,5 +1,10 @@ use std::any::TypeId; +use crate::fft_helper::{ + fft_helper_immut, fft_helper_immut_unroll2x, fft_helper_inplace, fft_helper_inplace_unroll2x, + fft_helper_outofplace, fft_helper_outofplace_unroll2x, +}; + /// Helper function to assert we have the right float type pub fn assert_f32() { let id_f32 = TypeId::of::(); @@ -63,29 +68,17 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_immut( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_immut(input, output, &mut []), + ); } } fn process_outofplace_with_scratch( @@ -94,66 +87,33 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { output: &mut [Complex], _scratch: &mut [Complex], ) { - if self.len() == 0 { - return; - } - - if input.len() < self.len() || output.len() != input.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here - } - - let result = unsafe { - array_utils::iter_chunks_zipped_mut( - input, - output, + unsafe { + let simd_input = crate::array_utils::workaround_transmute_mut(input); + let simd_output = crate::array_utils::workaround_transmute_mut(output); + super::wasm_simd_common::wasm_simd_fft_helper_outofplace( + simd_input, + simd_output, + &mut [], self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) - }, - ) - }; - - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + 0, + |input, output, _| self.perform_fft_out_of_place(input, output, &mut []), + ); } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { - if self.len() == 0 { - return; - } - - let required_scratch = self.get_inplace_scratch_len(); - if scratch.len() < required_scratch || buffer.len() < self.len() { - // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( - self.len(), - buffer.len(), - self.get_inplace_scratch_len(), - scratch.len(), - ); - return; // Unreachable, because fft_error_inplace 
asserts, but it helps codegen to put it here - } - - let scratch = &mut scratch[..required_scratch]; - let result = unsafe { - array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { - self.perform_fft_out_of_place(chunk, scratch, &mut []); - chunk.copy_from_slice(scratch); - }) - }; - if result.is_err() { - // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, - // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_inplace( + unsafe { + let simd_buffer = crate::array_utils::workaround_transmute_mut(buffer); + let simd_scratch = crate::array_utils::workaround_transmute_mut(scratch); + super::wasm_simd_common::wasm_simd_fft_helper_inplace( + simd_buffer, + simd_scratch, self.len(), - buffer.len(), self.get_inplace_scratch_len(), - scratch.len(), - ); + |chunk, scratch| { + self.perform_fft_out_of_place(chunk, scratch, &mut []); + chunk.copy_from_slice(scratch); + }, + ) } } #[inline(always)] @@ -166,7 +126,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { } #[inline(always)] fn get_immutable_scratch_len(&self) -> usize { - self.len() + 0 } } impl Length for $struct_name { @@ -183,3 +143,96 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { } }; } + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub unsafe fn wasm_simd_fft_helper_immut( + input: &[T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&[T], &mut [T], &mut [T]), +) { + fft_helper_immut( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub unsafe fn wasm_simd_fft_helper_outofplace( + input: &mut [T], + output: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T], &mut [T]), +) { + fft_helper_outofplace( + input, + output, + scratch, + chunk_size, + required_scratch, + chunk_fn, + ) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub unsafe fn wasm_simd_fft_helper_inplace( + buffer: &mut [T], + scratch: &mut [T], + chunk_size: usize, + required_scratch: usize, + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_inplace(buffer, scratch, chunk_size, required_scratch, chunk_fn) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub 
unsafe fn wasm_simd_fft_helper_immut_unroll2x( + input: &[T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&[T], &mut [T]), + chunk_fn: impl FnMut(&[T], &mut [T]), +) { + fft_helper_immut_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub unsafe fn wasm_simd_fft_helper_outofplace_unroll2x( + input: &mut [T], + output: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T], &mut [T]), + chunk_fn: impl FnMut(&mut [T], &mut [T]), +) { + fft_helper_outofplace_unroll2x(input, output, chunk_size, chunk2x_fn, chunk_fn) +} + +// A wrapper for the FFT helper functions that make sure the entire thing happens with the benefit of the Wasm SIMD target feature, +// so that things like loading twiddle factor registers etc can be lifted out of the loop +#[target_feature(enable = "simd128")] +pub unsafe fn wasm_simd_fft_helper_inplace_unroll2x( + buffer: &mut [T], + chunk_size: usize, + chunk2x_fn: impl FnMut(&mut [T]), + chunk_fn: impl FnMut(&mut [T]), +) { + fft_helper_inplace_unroll2x(buffer, chunk_size, chunk2x_fn, chunk_fn) +} diff --git a/src/wasm_simd/wasm_simd_prime_butterflies.rs b/src/wasm_simd/wasm_simd_prime_butterflies.rs index 4d46e6b..11a775f 100644 --- a/src/wasm_simd/wasm_simd_prime_butterflies.rs +++ b/src/wasm_simd/wasm_simd_prime_butterflies.rs @@ -5,10 +5,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -88,8 +85,7 @@ struct WasmSimdF32Butterfly7 { _phantom: std::marker::PhantomData, 
} -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly7); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly7, 7, |this: &WasmSimdF32Butterfly7<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly7, 7, |this: &WasmSimdF32Butterfly7<_>| this.direction); impl WasmSimdF32Butterfly7 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -193,8 +189,7 @@ struct WasmSimdF64Butterfly7 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly7); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly7, 7, |this: &WasmSimdF64Butterfly7<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly7, 7, |this: &WasmSimdF64Butterfly7<_>| this.direction); impl WasmSimdF64Butterfly7 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -269,8 +264,7 @@ struct WasmSimdF32Butterfly11 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly11); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly11, 11, |this: &WasmSimdF32Butterfly11<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly11, 11, |this: &WasmSimdF32Butterfly11<_>| this.direction); impl WasmSimdF32Butterfly11 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -424,8 +418,7 @@ struct WasmSimdF64Butterfly11 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly11); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly11, 11, |this: &WasmSimdF64Butterfly11<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly11, 11, |this: &WasmSimdF64Butterfly11<_>| this.direction); impl WasmSimdF64Butterfly11 { /// Safety: The current machine 
must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -542,8 +535,7 @@ struct WasmSimdF32Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly13); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly13, 13, |this: &WasmSimdF32Butterfly13<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly13, 13, |this: &WasmSimdF32Butterfly13<_>| this.direction); impl WasmSimdF32Butterfly13 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -728,8 +720,7 @@ struct WasmSimdF64Butterfly13 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly13); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly13, 13, |this: &WasmSimdF64Butterfly13<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly13, 13, |this: &WasmSimdF64Butterfly13<_>| this.direction); impl WasmSimdF64Butterfly13 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -873,8 +864,7 @@ struct WasmSimdF32Butterfly17 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly17); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly17, 17, |this: &WasmSimdF32Butterfly17<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly17, 17, |this: &WasmSimdF32Butterfly17<_>| this.direction); impl WasmSimdF32Butterfly17 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -1133,8 +1123,7 @@ struct WasmSimdF64Butterfly17 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly17); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly17, 17, |this: &WasmSimdF64Butterfly17<_>| this.direction); 
+boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly17, 17, |this: &WasmSimdF64Butterfly17<_>| this.direction); impl WasmSimdF64Butterfly17 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -1344,8 +1333,7 @@ struct WasmSimdF32Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly19); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly19, 19, |this: &WasmSimdF32Butterfly19<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly19, 19, |this: &WasmSimdF32Butterfly19<_>| this.direction); impl WasmSimdF32Butterfly19 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -1647,8 +1635,7 @@ struct WasmSimdF64Butterfly19 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly19); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly19, 19, |this: &WasmSimdF64Butterfly19<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly19, 19, |this: &WasmSimdF64Butterfly19<_>| this.direction); impl WasmSimdF64Butterfly19 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -1897,8 +1884,7 @@ struct WasmSimdF32Butterfly23 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly23); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly23, 23, |this: &WasmSimdF32Butterfly23<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly23, 23, |this: &WasmSimdF32Butterfly23<_>| this.direction); impl WasmSimdF32Butterfly23 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -2298,8 +2284,7 @@ struct WasmSimdF64Butterfly23 { _phantom: std::marker::PhantomData, } 
-boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly23); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly23, 23, |this: &WasmSimdF64Butterfly23<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly23, 23, |this: &WasmSimdF64Butterfly23<_>| this.direction); impl WasmSimdF64Butterfly23 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -2638,8 +2623,7 @@ struct WasmSimdF32Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly29); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly29, 29, |this: &WasmSimdF32Butterfly29<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly29, 29, |this: &WasmSimdF32Butterfly29<_>| this.direction); impl WasmSimdF32Butterfly29 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -3216,8 +3200,7 @@ struct WasmSimdF64Butterfly29 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly29); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly29, 29, |this: &WasmSimdF64Butterfly29<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly29, 29, |this: &WasmSimdF64Butterfly29<_>| this.direction); impl WasmSimdF64Butterfly29 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -3721,8 +3704,7 @@ struct WasmSimdF32Butterfly31 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly31); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF32Butterfly31, 31, |this: &WasmSimdF32Butterfly31<_>| this.direction); +boilerplate_fft_wasm_simd_f32_butterfly!(WasmSimdF32Butterfly31, 31, |this: &WasmSimdF32Butterfly31<_>| this.direction); impl WasmSimdF32Butterfly31 { /// Safety: 
The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] @@ -4366,8 +4348,7 @@ struct WasmSimdF64Butterfly31 { _phantom: std::marker::PhantomData, } -boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly31); -boilerplate_fft_wasm_simd_common_butterfly!(WasmSimdF64Butterfly31, 31, |this: &WasmSimdF64Butterfly31<_>| this.direction); +boilerplate_fft_wasm_simd_f64_butterfly!(WasmSimdF64Butterfly31, 31, |this: &WasmSimdF64Butterfly31<_>| this.direction); impl WasmSimdF64Butterfly31 { /// Safety: The current machine must support the simd128 instruction set #[target_feature(enable = "simd128")] diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index 7368257..1a34672 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -3,8 +3,7 @@ use num_complex::Complex; use std::any::TypeId; use std::sync::Arc; -use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{bitreversed_transpose, workaround_transmute_mut}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/wasm_simd/wasm_simd_vector.rs b/src/wasm_simd/wasm_simd_vector.rs index 4d0d27a..2d3d26b 100644 --- a/src/wasm_simd/wasm_simd_vector.rs +++ b/src/wasm_simd/wasm_simd_vector.rs @@ -237,7 +237,6 @@ pub struct Rotation90(V); // A trait to hold the BVectorType and COMPLEX_PER_VECTOR associated data pub trait WasmVector: Copy + Debug + Send + Sync { - const SCALAR_PER_VECTOR: usize; const COMPLEX_PER_VECTOR: usize; type ScalarType: WasmNum; @@ -252,6 +251,9 @@ pub trait WasmVector: Copy + Debug + Send + Sync { // stores of complex numbers unsafe fn store_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); + + // Keep this around even though it's unused - research went into how to do 
it, keeping it ensures that research doesn't need to be repeated + #[allow(unused)] unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); // math ops @@ -287,7 +289,6 @@ pub trait WasmVector: Copy + Debug + Send + Sync { } impl WasmVector for WasmVector32 { - const SCALAR_PER_VECTOR: usize = 4; const COMPLEX_PER_VECTOR: usize = 2; type ScalarType = f32; @@ -423,7 +424,6 @@ impl WasmVector for WasmVector32 { } impl WasmVector for WasmVector64 { - const SCALAR_PER_VECTOR: usize = 2; const COMPLEX_PER_VECTOR: usize = 1; type ScalarType = f64; @@ -559,22 +559,22 @@ impl WasmVector for WasmVector64 { } } -// A trait to handle reading from an array of complex floats into SSE vectors. -// SSE works with 128-bit vectors, meaning a vector can hold two complex f32, +// A trait to handle reading from an array of complex floats into Wasm SIMD vectors. +// Wasm SIMD works with 128-bit vectors, meaning a vector can hold two complex f32, // or a single complex f64. pub trait WasmSimdArray: Deref { - // Load complex numbers from the array to fill a SSE vector. + // Load complex numbers from the array to fill a Wasm SIMD vector. unsafe fn load_complex_v128(&self, index: usize) -> v128; - // Load a single complex number from the array into a SSE vector, setting the unused elements to zero. + // Load a single complex number from the array into a Wasm SIMD vector, setting the unused elements to zero. unsafe fn load_partial_lo_complex_v128(&self, index: usize) -> v128; - // Load a single complex number from the array, and copy it to all elements of a SSE vector. + // Load a single complex number from the array, and copy it to all elements of a Wasm SIMD vector. unsafe fn load1_complex_v128(&self, index: usize) -> v128; - // Load complex numbers from the array to fill a SSE vector. + // Load complex numbers from the array to fill a Wasm SIMD vector. 
unsafe fn load_complex(&self, index: usize) -> S::VectorType; - // Load a single complex number from the array into a SSE vector, setting the unused elements to zero. + // Load a single complex number from the array into a Wasm SIMD vector, setting the unused elements to zero. unsafe fn load_partial_lo_complex(&self, index: usize) -> S::VectorType; - // Load a single complex number from the array, and copy it to all elements of a SSE vector. + // Load a single complex number from the array, and copy it to all elements of a Wasm SIMD vector. unsafe fn load1_complex(&self, index: usize) -> S::VectorType; } @@ -684,16 +684,14 @@ where } } -// A trait to handle writing to an array of complex floats from SSE vectors. -// SSE works with 128-bit vectors, meaning a vector can hold two complex f32, +// A trait to handle writing to an array of complex floats from Wasm SIMD vectors. +// Wasm SIMD works with 128-bit vectors, meaning a vector can hold two complex f32, // or a single complex f64. pub trait WasmSimdArrayMut: WasmSimdArray + DerefMut { // Store all complex numbers from a SSE vector to the array. unsafe fn store_complex_v128(&mut self, vector: v128, index: usize); // Store the low complex number from a SSE vector to the array. unsafe fn store_partial_lo_complex_v128(&mut self, vector: v128, index: usize); - // Store the high complex number from a SSE vector to the array. - unsafe fn store_partial_hi_complex_v128(&mut self, vector: v128, index: usize); // Store all complex numbers from a SSE vector to the array. 
unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize); @@ -712,11 +710,6 @@ impl WasmSimdArrayMut for &mut [Complex] { debug_assert!(self.len() >= index + 1); S::VectorType::store_partial_lo_complex(self.as_mut_ptr().add(index), S::wrap(vector)) } - #[inline(always)] - unsafe fn store_partial_hi_complex_v128(&mut self, vector: v128, index: usize) { - debug_assert!(self.len() >= index + 1); - S::VectorType::store_partial_hi_complex(self.as_mut_ptr().add(index), S::wrap(vector)) - } #[inline(always)] unsafe fn store_complex(&mut self, vector: S::VectorType, index: usize) { @@ -743,10 +736,6 @@ where unsafe fn store_partial_lo_complex_v128(&mut self, vector: v128, index: usize) { self.output.store_partial_lo_complex_v128(vector, index); } - #[inline(always)] - unsafe fn store_partial_hi_complex_v128(&mut self, vector: v128, index: usize) { - self.output.store_partial_hi_complex_v128(vector, index); - } #[inline(always)] unsafe fn store_complex(&mut self, vector: T::VectorType, index: usize) { diff --git a/tests/accuracy.rs b/tests/accuracy.rs index a268cff..a45490e 100644 --- a/tests/accuracy.rs +++ b/tests/accuracy.rs @@ -16,6 +16,7 @@ use rustfft::{num_traits::Zero, FftDirection}; use rand::distributions::{uniform::SampleUniform, Distribution, Uniform}; use rand::{rngs::StdRng, SeedableRng}; +use wasm_bindgen_test::wasm_bindgen_test; /// The seed for the random number generator used to generate /// random signals. 
It's defined here so that we have deterministic @@ -37,7 +38,6 @@ fn compare_vectors(vec1: &[Complex], vec2: &[Comp fn fft_matches_control(control: Arc>, input: &[Complex]) -> bool { let mut control_input = input.to_vec(); - let mut test_input = input.to_vec(); let mut planner = FftPlanner::new(); let fft = planner.plan_fft(control.len(), control.fft_direction()); @@ -54,14 +54,31 @@ fn fft_matches_control(control: Arc>, input: &[Com let scratch_max = std::cmp::max( control.get_inplace_scratch_len(), - fft.get_inplace_scratch_len(), + std::cmp::max( + fft.get_inplace_scratch_len(), + std::cmp::max( + fft.get_outofplace_scratch_len(), + fft.get_immutable_scratch_len(), + ), + ), ); let mut scratch = vec![Zero::zero(); scratch_max]; control.process_with_scratch(&mut control_input, &mut scratch); - fft.process_with_scratch(&mut test_input, &mut scratch); - return compare_vectors(&test_input, &control_input); + let mut test_output_inplace = input.to_vec(); + fft.process_with_scratch(&mut test_output_inplace, &mut scratch); + + let mut input_oop = input.to_vec(); + let mut test_output_oop = input.to_vec(); + fft.process_outofplace_with_scratch(&mut input_oop, &mut test_output_oop, &mut scratch); + + let mut test_output_immut = input.to_vec(); + fft.process_immutable_with_scratch(input, &mut test_output_immut, &mut scratch); + + return compare_vectors(&control_input, &test_output_inplace) + && compare_vectors(&control_input, &test_output_oop) + && compare_vectors(&control_input, &test_output_immut); } fn random_signal(length: usize) -> Vec> { @@ -168,3 +185,64 @@ fn test_planned_fft_inverse_f64() { assert!(fft_matches_control(control, &signal), "length = {}", len); } } + +#[wasm_bindgen_test] +fn wasm_test_planned_fft_forward_f32() { + let direction = FftDirection::Forward; + let cache: ControlCache = ControlCache::new(TEST_MAX, direction); + + for len in 1..TEST_MAX { + println!("len: {len}"); + let control = cache.plan_fft(len); + assert_eq!(control.len(), len); + 
assert_eq!(control.fft_direction(), direction); + + let signal = random_signal(len); + assert!(fft_matches_control(control, &signal), "length = {}", len); + } +} + +#[wasm_bindgen_test] +fn wasm_test_planned_fft_inverse_f32() { + let direction = FftDirection::Inverse; + let cache: ControlCache = ControlCache::new(TEST_MAX, direction); + + for len in 1..TEST_MAX { + let control = cache.plan_fft(len); + assert_eq!(control.len(), len); + assert_eq!(control.fft_direction(), direction); + + let signal = random_signal(len); + assert!(fft_matches_control(control, &signal), "length = {}", len); + } +} + +#[wasm_bindgen_test] +fn wasm_test_planned_fft_forward_f64() { + let direction = FftDirection::Forward; + let cache: ControlCache = ControlCache::new(TEST_MAX, direction); + + for len in 1..TEST_MAX { + let control = cache.plan_fft(len); + assert_eq!(control.len(), len); + assert_eq!(control.fft_direction(), direction); + + let signal = random_signal(len); + assert!(fft_matches_control(control, &signal), "length = {}", len); + } +} + +#[wasm_bindgen_test] +fn wasm_test_planned_fft_inverse_f64() { + let direction = FftDirection::Inverse; + let cache: ControlCache = ControlCache::new(TEST_MAX, direction); + + for len in 1..TEST_MAX { + let control = cache.plan_fft(len); + assert_eq!(control.len(), len); + assert_eq!(control.fft_direction(), direction); + + let signal = random_signal(len); + assert!(fft_matches_control(control, &signal), "length = {}", len); + } +} diff --git a/tests/test_immutable.rs b/tests/test_immutable.rs deleted file mode 100644 index 029eca5..0000000 --- a/tests/test_immutable.rs +++ /dev/null @@ -1,52 +0,0 @@ -use num_complex::Complex; -use rustfft::{FftNum, FftPlanner}; - -const TEST_MAX: usize = 1001; - -#[test] -fn immutable_f32() { - for i in 0..TEST_MAX { - let input = vec![Complex::new(7.0, 8.0); i]; - - let mut_output = fft_wrapper_mut::(&input); - let immut_output = fft_wrapper_immut::(&input); - - assert_eq!(mut_output, immut_output, "{}", 
i); - } -} - -#[test] -fn immutable_f64() { - for i in 0..TEST_MAX { - let input = vec![Complex::new(7.0, 8.0); i]; - - let mut_output = fft_wrapper_mut::(&input); - let immut_output = fft_wrapper_immut::(&input); - - assert_eq!(mut_output, immut_output, "{}", i); - } -} - -fn fft_wrapper_mut(input: &[Complex]) -> Vec> { - let cz = Complex::new(T::zero(), T::zero()); - let mut plan = FftPlanner::::new(); - let p = plan.plan_fft_forward(input.len()); - - let mut scratch = vec![cz; p.get_inplace_scratch_len()]; - let mut output = input.to_vec(); - - p.process_with_scratch(&mut output, &mut scratch); - output -} - -fn fft_wrapper_immut(input: &[Complex]) -> Vec> { - let cz = Complex::new(T::zero(), T::zero()); - let mut plan = FftPlanner::::new(); - let p = plan.plan_fft_forward(input.len()); - - let mut scratch = vec![cz; p.get_immutable_scratch_len()]; - let mut output = vec![cz; input.len()]; - - p.process_immutable_with_scratch(input, &mut output, &mut scratch); - output -} diff --git a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs index 6937f5c..28b8fd5 100644 --- a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs +++ b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs @@ -5,10 +5,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; -use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -79,8 +76,7 @@ struct {{this.struct_name_32}} { _phantom: std::marker::PhantomData, } -boilerplate_fft_{{../arch.name_snakecase}}_f32_butterfly!({{this.struct_name_32}}); -boilerplate_fft_{{../arch.name_snakecase}}_common_butterfly!({{this.struct_name_32}}, {{this.len}}, |this: &{{this.struct_name_32}}<_>| this.direction); 
+boilerplate_fft_{{../arch.name_snakecase}}_f32_butterfly!({{this.struct_name_32}}, {{this.len}}, |this: &{{this.struct_name_32}}<_>| this.direction); impl {{this.struct_name_32}} { /// Safety: The current machine must support the {{../arch.cpu_feature_name}} instruction set #[target_feature(enable = "{{../arch.cpu_feature_name}}")] @@ -134,8 +130,7 @@ struct {{this.struct_name_64}} { _phantom: std::marker::PhantomData, } -boilerplate_fft_{{../arch.name_snakecase}}_f64_butterfly!({{this.struct_name_64}}); -boilerplate_fft_{{../arch.name_snakecase}}_common_butterfly!({{this.struct_name_64}}, {{this.len}}, |this: &{{this.struct_name_64}}<_>| this.direction); +boilerplate_fft_{{../arch.name_snakecase}}_f64_butterfly!({{this.struct_name_64}}, {{this.len}}, |this: &{{this.struct_name_64}}<_>| this.direction); impl {{this.struct_name_64}} { /// Safety: The current machine must support the {{../arch.cpu_feature_name}} instruction set #[target_feature(enable = "{{../arch.cpu_feature_name}}")]