diff --git a/benches/bench_rustfft.rs b/benches/bench_rustfft.rs index 299c72c..387c5f3 100644 --- a/benches/bench_rustfft.rs +++ b/benches/bench_rustfft.rs @@ -23,6 +23,13 @@ impl Fft for Noop { fn process_outofplace_with_scratch(&self, _input: &mut [Complex], _output: &mut [Complex], _scratch: &mut [Complex]) {} fn get_inplace_scratch_len(&self) -> usize { self.len } fn get_outofplace_scratch_len(&self) -> usize { 0 } + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + _scratch: &mut [Complex], + ) {} + fn get_immutable_scratch_len(&self) -> usize { 0 } } impl Length for Noop { fn len(&self) -> usize { self.len } diff --git a/benches/bench_rustfft_scalar.rs b/benches/bench_rustfft_scalar.rs index b57a5d3..371aebf 100644 --- a/benches/bench_rustfft_scalar.rs +++ b/benches/bench_rustfft_scalar.rs @@ -33,6 +33,18 @@ impl Fft for Noop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + _scratch: &mut [Complex], + ) { + } + + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for Noop { fn len(&self) -> usize { diff --git a/src/algorithm/bluesteins_algorithm.rs b/src/algorithm/bluesteins_algorithm.rs index 9cc3642..315aa39 100644 --- a/src/algorithm/bluesteins_algorithm.rs +++ b/src/algorithm/bluesteins_algorithm.rs @@ -137,9 +137,10 @@ impl BluesteinsAlgorithm { } } - fn perform_fft_out_of_place( + #[inline] + fn perform_fft_immut( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], ) { @@ -179,6 +180,15 @@ impl BluesteinsAlgorithm { *buffer_entry = inner_entry.conj() * twiddle; } } + + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, scratch); + } } boilerplate_fft!( BluesteinsAlgorithm, @@ -186,7 +196,9 @@ boilerplate_fft!( |this: 
&BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() + this.inner_fft.get_inplace_scratch_len(), // in-place scratch len |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() - + this.inner_fft.get_inplace_scratch_len() // out of place scratch len + + this.inner_fft.get_inplace_scratch_len(), // out of place scratch len + |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() + + this.inner_fft.get_inplace_scratch_len() // immut scratch len ); #[cfg(test)] diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 1292618..7bae332 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -3,7 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; use crate::array_utils::{self, DoubleBuf, LoadStore}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -17,6 +17,39 @@ macro_rules! 
boilerplate_fft_butterfly { } } impl Fft for $struct_name { + #[inline] + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + self.perform_fft_butterfly(DoubleBuf { + input: in_chunk, + output: out_chunk, + }) + }; + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], output: &mut [Complex], @@ -26,7 +59,7 @@ macro_rules! boilerplate_fft_butterfly { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } let result = array_utils::iter_chunks_zipped( @@ -56,7 +89,7 @@ macro_rules! 
boilerplate_fft_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| unsafe { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| unsafe { self.perform_fft_butterfly(chunk) }); @@ -74,6 +107,10 @@ macro_rules! boilerplate_fft_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -104,13 +141,22 @@ impl Butterfly1 { } } impl Fft for Butterfly1 { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + output.copy_from_slice(input); + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { - output.copy_from_slice(&input); + output.copy_from_slice(input); } fn process_with_scratch(&self, _buffer: &mut [Complex], _scratch: &mut [Complex]) {} @@ -122,6 +168,10 @@ impl Fft for Butterfly1 { fn get_outofplace_scratch_len(&self) -> usize { 0 } + + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for Butterfly1 { fn len(&self) -> usize { diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index e0b700c..95f781f 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -2,7 +2,7 @@ use num_complex::Complex; use num_traits::Zero; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; @@ -45,7 +45,7 @@ impl Dft { 0 } - fn perform_fft_out_of_place( + fn perform_fft_immut( &self, signal: &[Complex], spectrum: &mut [Complex], @@ -68,8 +68,17 @@ impl Dft { } } } + + fn perform_fft_out_of_place( + &self, + signal: &[Complex], + spectrum: &mut 
[Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(signal, spectrum, _scratch); + } } -boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len()); +boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len(), |_: &Dft<_>| 0); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/good_thomas_algorithm.rs b/src/algorithm/good_thomas_algorithm.rs index 005ba52..db4b591 100644 --- a/src/algorithm/good_thomas_algorithm.rs +++ b/src/algorithm/good_thomas_algorithm.rs @@ -50,6 +50,7 @@ pub struct GoodThomasAlgorithm { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, len: usize, direction: FftDirection, @@ -117,6 +118,11 @@ impl GoodThomasAlgorithm { height_outofplace_scratch, ); + let immut_scratch_len = max( + width_fft.get_inplace_scratch_len(), + len + height_fft.get_inplace_scratch_len(), + ); + Self { width, width_size_fft: width_fft, @@ -129,6 +135,7 @@ impl GoodThomasAlgorithm { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len, len, direction, @@ -241,6 +248,31 @@ impl GoodThomasAlgorithm { self.reindex_output(scratch, buffer); } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // Re-index the input, copying from the input to the output in the process + self.reindex_input(input, output); + + // run FFTs of size `width` + self.width_size_fft.process_with_scratch(output, scratch); + + let (scratch, inner_scratch) = scratch.split_at_mut(self.len()); + + // transpose + transpose::transpose(output, scratch, self.width, self.height); + + // run FFTs of size 'height' + self.height_size_fft + .process_with_scratch(scratch, inner_scratch); + + // Re-index the output, copying from the input to the output in the process + self.reindex_output(scratch, output); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -279,7 +311,8 @@ boilerplate_fft!( GoodThomasAlgorithm, |this: &GoodThomasAlgorithm<_>| this.len, |this: 
&GoodThomasAlgorithm<_>| this.inplace_scratch_len, - |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len + |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len, + |this: &GoodThomasAlgorithm<_>| this.immut_scratch_len ); /// Implementation of the Good-Thomas Algorithm, specialized for smaller input sizes @@ -384,6 +417,38 @@ impl GoodThomasAlgorithmSmall { } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // These asserts are for the unsafe blocks down below. we're relying on the optimizer to get rid of this assert + assert_eq!(self.len(), input.len()); + assert_eq!(self.len(), output.len()); + + let (input_map, output_map) = self.input_output_map.split_at(self.len()); + + // copy the input using our reordering mapping + for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) { + *output_element = input[input_index]; + } + + // run FFTs of size `width` + self.width_size_fft.process_with_scratch(output, scratch); + + // transpose + unsafe { array_utils::transpose_small(self.width, self.height, output, scratch) }; + + // run FFTs of size 'height' + self.height_size_fft.process_with_scratch(scratch, output); + + // copy to the output, using our output reordering mapping + for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) { + output[output_index] = *input_element; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -448,7 +513,8 @@ boilerplate_fft!( GoodThomasAlgorithmSmall, |this: &GoodThomasAlgorithmSmall<_>| this.width * this.height, |this: &GoodThomasAlgorithmSmall<_>| this.len(), - |_| 0 + |_| 0, + |this: &GoodThomasAlgorithmSmall<_>| this.len() ); #[cfg(test)] @@ -532,12 +598,15 @@ mod unit_tests { for &len in &scratch_lengths { for &inplace_scratch in &scratch_lengths { for &outofplace_scratch in &scratch_lengths { - inner_ffts.push(Arc::new(BigScratchAlgorithm { - len, - inplace_scratch, - outofplace_scratch, - 
direction: FftDirection::Forward, - }) as Arc>); + for &immut_scratch in &scratch_lengths { + inner_ffts.push(Arc::new(BigScratchAlgorithm { + len, + inplace_scratch, + outofplace_scratch, + immut_scratch, + direction: FftDirection::Forward, + }) as Arc>); + } } } } @@ -565,6 +634,16 @@ mod unit_tests { &mut outofplace_output, &mut outofplace_scratch, ); + + let immut_input = vec![Complex::zero(); fft.len()]; + let mut immut_output = vec![Complex::zero(); fft.len()]; + let mut immut_scratch = vec![Complex::zero(); fft.get_immutable_scratch_len()]; + + fft.process_immutable_with_scratch( + &immut_input, + &mut immut_output, + &mut immut_scratch, + ); } } } diff --git a/src/algorithm/mixed_radix.rs b/src/algorithm/mixed_radix.rs index 808ff0f..66405eb 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -44,6 +44,7 @@ pub struct MixedRadix { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, } @@ -103,6 +104,11 @@ impl MixedRadix { width_outofplace_scratch, ); + let immut_scratch_len = max( + len + width_fft.get_inplace_scratch_len(), + height_fft.get_inplace_scratch_len(), + ); + Self { twiddles: twiddles.into_boxed_slice(), @@ -114,6 +120,7 @@ impl MixedRadix { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len, direction, } @@ -151,6 +158,37 @@ impl MixedRadix { transpose::transpose(scratch, buffer, self.width, self.height); } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch_raw: &mut [Complex], + ) { + // STEP 1: transpose + transpose::transpose(input, output, self.width, self.height); + + // STEP 2: perform FFTs of size `height` + self.height_size_fft + .process_with_scratch(output, scratch_raw); + + // STEP 3: Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + let (scratch, inner_scratch) = scratch_raw.split_at_mut(self.len()); + + // 
STEP 4: transpose again + transpose::transpose(output, scratch, self.height, self.width); + + // STEP 5: perform FFTs of size `width` + self.width_size_fft + .process_with_scratch(scratch, inner_scratch); + + // STEP 6: transpose again + transpose::transpose(scratch, output, self.width, self.height); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -196,7 +234,8 @@ boilerplate_fft!( MixedRadix, |this: &MixedRadix<_>| this.twiddles.len(), |this: &MixedRadix<_>| this.inplace_scratch_len, - |this: &MixedRadix<_>| this.outofplace_scratch_len + |this: &MixedRadix<_>| this.outofplace_scratch_len, + |this: &MixedRadix<_>| this.immut_scratch_len ); /// Implementation of the Mixed-Radix FFT algorithm, specialized for smaller input sizes @@ -302,6 +341,34 @@ impl MixedRadixSmall { unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) }; } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // SIX STEP FFT: + // STEP 1: transpose + unsafe { array_utils::transpose_small(self.width, self.height, input, output) }; + + // STEP 2: perform FFTs of size `height` + self.height_size_fft.process_with_scratch(output, scratch); + + // STEP 3: Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + // STEP 4: transpose again + unsafe { array_utils::transpose_small(self.height, self.width, output, scratch) }; + + // STEP 5: perform FFTs of size `width` + self.width_size_fft.process_with_scratch(scratch, output); + + // STEP 6: transpose again + unsafe { array_utils::transpose_small(self.width, self.height, scratch, output) }; + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -334,7 +401,8 @@ boilerplate_fft!( MixedRadixSmall, |this: &MixedRadixSmall<_>| this.twiddles.len(), |this: &MixedRadixSmall<_>| this.len(), - |_| 0 + |_| 0, + |this: &MixedRadixSmall<_>| this.len() ); #[cfg(test)] 
@@ -393,12 +461,15 @@ mod unit_tests { for &len in &scratch_lengths { for &inplace_scratch in &scratch_lengths { for &outofplace_scratch in &scratch_lengths { - inner_ffts.push(Arc::new(BigScratchAlgorithm { - len, - inplace_scratch, - outofplace_scratch, - direction: FftDirection::Forward, - }) as Arc>); + for &immut_scratch in &scratch_lengths { + inner_ffts.push(Arc::new(BigScratchAlgorithm { + len, + inplace_scratch, + outofplace_scratch, + immut_scratch, + direction: FftDirection::Forward, + }) as Arc>); + } } } } @@ -421,6 +492,16 @@ mod unit_tests { &mut outofplace_output, &mut outofplace_scratch, ); + + let immut_input = vec![Complex::zero(); fft.len()]; + let mut immut_output = vec![Complex::zero(); fft.len()]; + let mut immut_scratch = vec![Complex::zero(); fft.get_immutable_scratch_len()]; + + fft.process_immutable_with_scratch( + &immut_input, + &mut immut_output, + &mut immut_scratch, + ); } } } diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 7f059dd..3a4df60 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -50,6 +50,8 @@ pub struct RadersAlgorithm { len: StrengthReducedUsize, inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, + direction: FftDirection, } @@ -100,6 +102,8 @@ impl RadersAlgorithm { } else { required_inner_scratch }; + let inplace_scratch_len = inner_fft_len + extra_inner_scratch; + let immut_scratch_len = inner_fft_len + required_inner_scratch; //precompute a FFT of our reordered twiddle factors let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch]; @@ -113,12 +117,60 @@ impl RadersAlgorithm { primitive_root_inverse, len: reduced_len, - inplace_scratch_len: inner_fft_len + extra_inner_scratch, + inplace_scratch_len, outofplace_scratch_len: extra_inner_scratch, + immut_scratch_len, direction, } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) 
{ + // The first output element is just the sum of all the input elements, and we need to store off the first input value + let (output_first, output) = output.split_first_mut().unwrap(); + let (input_first, input) = input.split_first().unwrap(); + let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1); + + // copy the input into the scratch space, reordering as we go + let mut input_index = 1; + for output_element in scratch.iter_mut() { + input_index = (input_index * self.primitive_root) % self.len; + + let input_element = input[input_index - 1]; + *output_element = input_element; + } + + self.inner_fft.process_with_scratch(scratch, extra_scratch); + + // output[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input + *output_first = *input_first + scratch[0]; + + // multiply the inner result with our cached setup data + // also conjugate every entry. this sets us up to do an inverse FFT + // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) + for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) { + *scratch_cell = (*scratch_cell * twiddle).conj(); + } + + // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. 
+ // Of course, we have to conjugate it, just like we conjugated the complex multiplied above + scratch[0] = scratch[0] + input_first.conj(); + + // execute the second FFT + self.inner_fft.process_with_scratch(scratch, extra_scratch); + + // copy the final values into the output, reordering as we go + let mut output_index = 1; + for scratch_element in scratch { + output_index = (output_index * self.primitive_root_inverse) % self.len; + output[output_index - 1] = scratch_element.conj(); + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -232,7 +284,8 @@ boilerplate_fft!( RadersAlgorithm, |this: &RadersAlgorithm<_>| this.len.get(), |this: &RadersAlgorithm<_>| this.inplace_scratch_len, - |this: &RadersAlgorithm<_>| this.outofplace_scratch_len + |this: &RadersAlgorithm<_>| this.outofplace_scratch_len, + |this: &RadersAlgorithm<_>| this.immut_scratch_len ); #[cfg(test)] diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index d392f8c..182ab4f 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9}; use crate::algorithm::radixn::butterfly_3; use crate::array_utils::{self, bitreversed_transpose, compute_logarithm}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -32,8 +32,10 @@ pub struct Radix3 { len: usize, direction: FftDirection, + inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl Radix3 { @@ -108,6 +110,7 @@ impl Radix3 { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -118,6 +121,42 @@ impl Radix3 { self.outofplace_scratch_len } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut 
[Complex], + ) { + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 3>(self.base_len, input, output); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, scratch); + + // cross-FFTs + const ROW_COUNT: usize = 3; + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } + + #[inline] fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -154,7 +193,8 @@ impl Radix3 { } } } -boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); +boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len, |this: &Radix3<_>| this + .immut_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 33a804e..b040cfe 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -7,7 +7,7 @@ use crate::algorithm::butterflies::{ }; use crate::algorithm::radixn::butterfly_4; use crate::array_utils::{self, bitreversed_transpose}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -35,6 +35,7 @@ pub struct Radix4 { direction: FftDirection, inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl Radix4 { @@ -114,6 +115,7 @@ impl Radix4 { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -124,6 
+126,43 @@ impl Radix4 { self.outofplace_scratch_len } + #[inline] + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 4>(self.base_len, input, output); + } + + self.base_fft.process_with_scratch(output, scratch); + + // cross-FFTs + const ROW_COUNT: usize = 4; + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + let butterfly4 = Butterfly4::new(self.direction); + + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -162,7 +201,8 @@ impl Radix4 { } } } -boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len); +boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len, |this: &Radix4<_>| this + .immut_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index df56181..a85a6bf 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use num_complex::Complex; use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor}; -use crate::common::{fft_error_inplace, fft_error_outofplace, RadixFactor}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace, RadixFactor}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; @@ -43,8 +43,10 @@ pub(crate) struct RadixN { len: usize, direction: FftDirection, + inplace_scratch_len: 
usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl RadixN { @@ -148,6 +150,7 @@ impl RadixN { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -158,6 +161,89 @@ impl RadixN { self.outofplace_scratch_len } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if let Some(unroll_factor) = self.factors.first() { + // for performance, we really, really want to unroll the transpose, but we need to make sure the output length is divisible by the unroll amount + // choosing the first factor seems to reliably perform well + match unroll_factor.factor { + RadixFactor::Factor2 => { + factor_transpose::, 2>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor3 => { + factor_transpose::, 3>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor4 => { + factor_transpose::, 4>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor5 => { + factor_transpose::, 5>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor6 => { + factor_transpose::, 6>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor7 => { + factor_transpose::, 7>(self.base_len, input, output, &self.factors) + } + } + } else { + // no factors, so just pass data straight to our base + output.copy_from_slice(input); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, scratch); + + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + for factor in self.butterflies.iter() { + let cross_fft_columns = cross_fft_len; + cross_fft_len *= factor.radix(); + + match factor { + InternalRadixFactor::Factor2(butterfly2) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_2(data, layer_twiddles, cross_fft_columns, butterfly2) } + } + } + InternalRadixFactor::Factor3(butterfly3) => { + for data in 
output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, cross_fft_columns, butterfly3) } + } + } + InternalRadixFactor::Factor4(butterfly4) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, cross_fft_columns, butterfly4) } + } + } + InternalRadixFactor::Factor5(butterfly5) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_5(data, layer_twiddles, cross_fft_columns, butterfly5) } + } + } + InternalRadixFactor::Factor6(butterfly6) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_6(data, layer_twiddles, cross_fft_columns, butterfly6) } + } + } + InternalRadixFactor::Factor7(butterfly7) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_7(data, layer_twiddles, cross_fft_columns, butterfly7) } + } + } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = cross_fft_columns * (factor.radix() - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -243,7 +329,8 @@ impl RadixN { } } } -boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len); +boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len, |this: &RadixN<_>| this + .immut_scratch_len); #[inline(never)] pub(crate) unsafe fn butterfly_2( diff --git a/src/array_utils.rs b/src/array_utils.rs index 8f0fe3c..4058987 100644 --- a/src/array_utils.rs +++ b/src/array_utils.rs @@ -145,7 +145,7 @@ mod unit_tests { // Loop over exact chunks of the provided buffer. 
Very similar in semantics to ChunksExactMut, but generates smaller code and requires no modulo operations // Returns Ok() if every element ended up in a chunk, Err() if there was a remainder -pub fn iter_chunks( +pub fn iter_chunks_mut( mut buffer: &mut [T], chunk_size: usize, mut chunk_fn: impl FnMut(&mut [T]), @@ -169,20 +169,62 @@ pub fn iter_chunks( // Loop over exact zipped chunks of the 2 provided buffers. Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations // Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder pub fn iter_chunks_zipped( + mut buffer1: &[T], + mut buffer2: &mut [T], + chunk_size: usize, + mut chunk_fn: impl FnMut(&[T], &mut [T]), +) -> Result<(), ()> { + // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size + let uneven = match buffer1.len().cmp(&buffer2.len()) { + std::cmp::Ordering::Less => { + buffer2 = &mut buffer2[..buffer1.len()]; + true + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + buffer1 = &buffer1[..buffer2.len()]; + true + } + }; + + // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each + while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size { + let (head1, tail1) = buffer1.split_at(chunk_size); + buffer1 = tail1; + + let (head2, tail2) = buffer2.split_at_mut(chunk_size); + buffer2 = tail2; + + chunk_fn(head1, head2); + } + + // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder + if !uneven && buffer1.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations +// Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder +pub fn iter_chunks_zipped_mut( mut buffer1: &mut [T], mut buffer2: &mut [T], chunk_size: usize, mut chunk_fn: impl FnMut(&mut [T], &mut [T]), ) -> Result<(), ()> { // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size - let uneven = if buffer1.len() > buffer2.len() { - buffer1 = &mut buffer1[..buffer2.len()]; - true - } else if buffer2.len() < buffer1.len() { - buffer2 = &mut buffer2[..buffer1.len()]; - true - } else { - false + let uneven = match buffer1.len().cmp(&buffer2.len()) { + std::cmp::Ordering::Less => { + buffer2 = &mut buffer2[..buffer1.len()]; + true + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + buffer1 = &mut buffer1[..buffer2.len()]; + true + } }; // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index 87f9c10..e0c8539 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -5,9 +5,9 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -37,6 +37,41 @@ macro_rules! 
boilerplate_fft_simd_butterfly { } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices + let input_slice = workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_f32(DoubleBuf { + input: input_slice, + output: output_slice, + }); + } + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -56,7 +91,7 @@ macro_rules! boilerplate_fft_simd_butterfly { |in_chunk, out_chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_f32(DoubleBuf { input: input_slice, @@ -79,7 +114,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f32(workaround_transmute_mut::<_, Complex>(chunk)); @@ -100,6 +135,10 @@ macro_rules! boilerplate_fft_simd_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -156,11 +195,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { } #[inline] - fn perform_fft_out_of_place( - &self, - input: &mut [Complex], - output: &mut [Complex], - ) { + fn perform_fft_immut(&self, input: &[Complex], output: &mut [Complex]) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available unsafe { self.column_butterflies_and_transpose(input, output) }; @@ -169,8 +204,47 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available unsafe { self.row_butterflies(output) }; } + + #[inline] + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + ) { + self.perform_fft_immut(input, output); + } } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + unsafe { array_utils::workaround_transmute(input) }; + let transmuted_output: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(output) }; + let result = array_utils::iter_chunks_zipped( + transmuted_input, + transmuted_output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +262,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +290,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); @@ -234,6 +308,10 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index c15d545..e6bc7d0 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -5,9 +5,9 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -35,6 +35,41 @@ macro_rules! 
boilerplate_fft_simd_butterfly { } } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices + let input_slice = workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_f64(DoubleBuf { + input: input_slice, + output: output_slice, + }); + } + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -54,7 +89,7 @@ macro_rules! boilerplate_fft_simd_butterfly { |in_chunk, out_chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_f64(DoubleBuf { input: input_slice, @@ -77,7 +112,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f64(workaround_transmute_mut::<_, Complex>(chunk)); @@ -98,6 +133,10 @@ macro_rules! boilerplate_fft_simd_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -156,11 +195,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { } #[inline] - fn perform_fft_out_of_place( - &self, - input: &mut [Complex], - output: &mut [Complex], - ) { + fn perform_fft_immut(&self, input: &[Complex], output: &mut [Complex]) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available unsafe { self.column_butterflies_and_transpose(input, output) }; @@ -169,8 +204,47 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available unsafe { self.row_butterflies(output) }; } + + #[inline] + fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + ) { + self.perform_fft_immut(input, output); + } } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + unsafe { array_utils::workaround_transmute(input) }; + let transmuted_output: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(output) }; + let result = array_utils::iter_chunks_zipped( + transmuted_input, + transmuted_output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +262,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +290,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); @@ -234,6 +308,10 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index 30e8f8e..d85020b 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use num_integer::div_ceil; use num_traits::Zero; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; @@ -144,6 +144,7 @@ impl BluesteinsAvx { inplace_scratch_len: required_scratch, outofplace_scratch_len: required_scratch, + immut_scratch_len: required_scratch, direction, }, @@ -311,7 +312,7 @@ impl BluesteinsAvx { } } - fn perform_fft_out_of_place( + fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], @@ -365,6 +366,15 @@ impl BluesteinsAvx { self.finalize_bluesteins(transmuted_inner_input, transmuted_output) } } + 
+ fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, scratch); + } } #[cfg(test)] diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index ed6de0f..a64476c 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use num_integer::div_ceil; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{Direction, Fft, FftDirection, FftNum, Length}; use super::{AvxNum, CommonSimdData}; @@ -69,6 +69,43 @@ macro_rules! boilerplate_mixedradix { } } + #[inline] + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + let (scratch, inner_scratch) = scratch.split_at_mut(input.len()); + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch); + + self.perform_column_butterflies_immut(transmuted_input, transmuted_output); + } + + // process the row FFTs. If extra scratch was provided, pass it in. Otherwise, use the output. 
+ self.common_data + .inner_fft + .process_with_scratch(scratch, inner_scratch); + + // Transpose + // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + + self.transpose(transmuted_input, transmuted_output) + } + } + #[inline] fn perform_fft_out_of_place( &self, @@ -138,11 +175,13 @@ macro_rules! mixedradix_gen_data { let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len(); let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len(); + let immut_scratch_len = len + $inner_fft.get_inplace_scratch_len(); CommonSimdData { twiddles: twiddles.into_boxed_slice(), inplace_scratch_len: len + inner_outofplace_scratch, outofplace_scratch_len: if inner_inplace_scratch > len { inner_inplace_scratch } else { 0 }, + immut_scratch_len, inner_fft: $inner_fft, len, direction, @@ -152,6 +191,124 @@ macro_rules! mixedradix_gen_data { macro_rules! 
mixedradix_column_butterflies { ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => { + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_column_butterflies_immut( + &self, + input: impl AvxArray, + mut buffer: impl AvxArrayMut, + ) { + // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc + const ROW_COUNT: usize = $row_count; + const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1; + + let len_per_row = self.len() / ROW_COUNT; + let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR; + + // process the column FFTs + for (c, twiddle_chunk) in self + .common_data + .twiddles + .chunks_exact(TWIDDLES_PER_COLUMN) + .take(chunk_count) + .enumerate() + { + let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; + + // Load columns from the input into registers + let mut columns = [AvxVector::zero(); ROW_COUNT]; + for i in 0..ROW_COUNT { + columns[i] = input.load_complex(index_base + len_per_row * i); + } + + // apply our butterfly function down the columns + let output = $butterfly_fn(columns, self); + + // always write the first row directly back without twiddles + buffer.store_complex(output[0], index_base); + + // for every other row, apply twiddle factors and then write back to memory + for i in 1..ROW_COUNT { + let twiddle = twiddle_chunk[i - 1]; + let output = AvxVector::mul_complex(twiddle, output[i]); + buffer.store_complex(output, index_base + len_per_row * i); + } + } + + // finally, we might have a remainder chunk + // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns + let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; + if partial_remainder > 0 { + let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; + let partial_remainder_twiddle_base = + self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; + let final_twiddle_chunk = + 
&self.common_data.twiddles[partial_remainder_twiddle_base..]; + + if partial_remainder > 2 { + // Load 3 columns into full AVX vectors to process our remainder + let mut columns = [AvxVector::zero(); ROW_COUNT]; + for i in 0..ROW_COUNT { + columns[i] = + input.load_partial3_complex(partial_remainder_base + len_per_row * i); + } + + // apply our butterfly function down the columns + let mid = $butterfly_fn(columns, self); + + // always write the first row without twiddles + buffer.store_partial3_complex(mid[0], partial_remainder_base); + + // for the remaining rows, apply twiddle factors and then write back to memory + for i in 1..ROW_COUNT { + let twiddle = final_twiddle_chunk[i - 1]; + let output = AvxVector::mul_complex(twiddle, mid[i]); + buffer.store_partial3_complex( + output, + partial_remainder_base + len_per_row * i, + ); + } + } else { + // Load 1 or 2 columns into half vectors to process our remainder. Thankfully, the compiler is smart enough to eliminate this branch on f64, since the partial remainder can only possibly be 1 + let mut columns = [AvxVector::zero(); ROW_COUNT]; + if partial_remainder == 1 { + for i in 0..ROW_COUNT { + columns[i] = input + .load_partial1_complex(partial_remainder_base + len_per_row * i); + } + } else { + for i in 0..ROW_COUNT { + columns[i] = input + .load_partial2_complex(partial_remainder_base + len_per_row * i); + } + } + + // apply our butterfly function down the columns + let mut mid = $butterfly_fn_lo(columns, self); + + // apply twiddle factors + for i in 1..ROW_COUNT { + mid[i] = AvxVector::mul_complex(final_twiddle_chunk[i - 1].lo(), mid[i]); + } + + // store output + if partial_remainder == 1 { + for i in 0..ROW_COUNT { + buffer.store_partial1_complex( + mid[i], + partial_remainder_base + len_per_row * i, + ); + } + } else { + for i in 0..ROW_COUNT { + buffer.store_partial2_complex( + mid[i], + partial_remainder_base + len_per_row * i, + ); + } + } + } + } + } #[target_feature(enable = "avx", enable = "fma")] 
unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut) { // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc @@ -862,6 +1019,127 @@ impl MixedRadix16xnAvx { } } } + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_column_butterflies_immut( + &self, + input: impl AvxArray, + mut buffer: impl AvxArrayMut, + ) { + // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc + const ROW_COUNT: usize = 16; + const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1; + + let len_per_row = self.len() / ROW_COUNT; + let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR; + + // process the column FFTs + for (c, twiddle_chunk) in self + .common_data + .twiddles + .chunks_exact(TWIDDLES_PER_COLUMN) + .take(chunk_count) + .enumerate() + { + let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; + + column_butterfly16_loadfn!( + |index| input.load_complex(index_base + len_per_row * index), + |mut data, index| { + if index > 0 { + data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]); + } + buffer.store_complex(data, index_base + len_per_row * index) + }, + self.twiddles_butterfly16, + self.twiddles_butterfly4 + ); + } + + // finally, we might have a single partial chunk. 
+ // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns + let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; + if partial_remainder > 0 { + let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; + let partial_remainder_twiddle_base = + self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; + let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..]; + + match partial_remainder { + 1 => { + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial1_complex(input.load_partial1_complex(cs), cs); + } + column_butterfly16_loadfn!( + |index| buffer + .load_partial1_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + let twiddle: A::VectorType = final_twiddle_chunk[index - 1]; + data = AvxVector::mul_complex(data, twiddle.lo()); + } + buffer.store_partial1_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + [ + self.twiddles_butterfly16[0].lo(), + self.twiddles_butterfly16[1].lo() + ], + self.twiddles_butterfly4.lo() + ); + } + 2 => { + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial2_complex(input.load_partial2_complex(cs), cs); + } + column_butterfly16_loadfn!( + |index| buffer + .load_partial2_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + let twiddle: A::VectorType = final_twiddle_chunk[index - 1]; + data = AvxVector::mul_complex(data, twiddle.lo()); + } + buffer.store_partial2_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + [ + self.twiddles_butterfly16[0].lo(), + self.twiddles_butterfly16[1].lo() + ], + self.twiddles_butterfly4.lo() + ); + } + 3 => { + for c in 0..self.len() / len_per_row { + let 
cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial3_complex(input.load_partial3_complex(cs), cs); + } + column_butterfly16_loadfn!( + |index| buffer + .load_partial3_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]); + } + buffer.store_partial3_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + self.twiddles_butterfly16, + self.twiddles_butterfly4 + ); + } + _ => unreachable!(), + } + } + } mixedradix_transpose!(16, AvxVector::transpose16_packed, AvxVector::transpose16_packed, diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index 34b76b1..439395e 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -8,7 +8,7 @@ use num_traits::Zero; use primal_check::miller_rabin; use strength_reduce::StrengthReducedUsize; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, FftDirection}; use crate::{math_utils, twiddles}; use crate::{Direction, Fft, FftNum, Length}; @@ -107,6 +107,7 @@ pub struct RadersAvx2 { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, _phantom: std::marker::PhantomData, @@ -267,6 +268,7 @@ impl RadersAvx2 { }) .collect::>() }; + let inplace_scratch_len = len + extra_inner_scratch; Self { input_index_multiplier, input_index_init, @@ -278,8 +280,9 @@ impl RadersAvx2 { len, - inplace_scratch_len: len + extra_inner_scratch, + inplace_scratch_len, outofplace_scratch_len: extra_inner_scratch, + immut_scratch_len: inner_fft_len + required_inner_scratch + 1, direction, _phantom: std::marker::PhantomData, @@ -351,6 +354,64 @@ impl RadersAvx2 { } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + unsafe { + // Specialization workaround: See 
the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.prepare_raders(transmuted_input, transmuted_output) + } + + let (first_input, _) = input.split_first().unwrap(); + let (first_output, inner_output) = output.split_first_mut().unwrap(); + let (scratch2, extra_scratch) = scratch.split_at_mut(self.len()); + let (_, scratch) = scratch2.split_first_mut().unwrap(); + + self.inner_fft.process_with_scratch(inner_output, scratch); + + // inner_output[0] now contains the sum of elements 1..n. we want the sum of all inputs, so all we need to do is add the first input + *first_output = inner_output[0] + *first_input; + + // multiply the inner result with our cached setup data + // also conjugate every entry. this sets us up to do an inverse FFT + // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch); + let transmuted_inner_output: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_output); + avx_vector::pairwise_complex_mul_conjugated( + transmuted_inner_output, + transmuted_inner_input, + &self.twiddles, + ) + }; + + // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. 
+ // Of course, we have to conjugate it, just like we conjugated the complex multiplied above + scratch[0] = scratch[0] + first_input.conj(); + + self.inner_fft.process_with_scratch(scratch, extra_scratch); + scratch2[0] = *first_input; + + // copy the final values into the output, reordering as we go + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch2); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.finalize_raders(transmuted_input, transmuted_output); + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -385,13 +446,13 @@ impl RadersAvx2 { // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(inner_input); let transmuted_inner_output: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + let transmuted_inner_input: &mut [Complex] = array_utils::workaround_transmute_mut(inner_output); avx_vector::pairwise_complex_mul_conjugated( - transmuted_inner_output, transmuted_inner_input, + transmuted_inner_output, &self.twiddles, ) }; @@ -479,7 +540,8 @@ boilerplate_avx_fft!( RadersAvx2, |this: &RadersAvx2<_, _>| this.len, |this: &RadersAvx2<_, _>| this.inplace_scratch_len, - |this: &RadersAvx2<_, _>| this.outofplace_scratch_len + |this: &RadersAvx2<_, _>| this.outofplace_scratch_len, + |this: &RadersAvx2<_, _>| this.immut_scratch_len ); #[cfg(test)] diff --git a/src/avx/avx_vector.rs b/src/avx/avx_vector.rs index c5f4c94..97807f2 100644 --- a/src/avx/avx_vector.rs +++ b/src/avx/avx_vector.rs @@ -2265,7 +2265,7 
@@ pub unsafe fn pairwise_complex_mul_conjugated( multiplier.len(), input.len() ); // Assert to convince the compiler to omit bounds checks inside the loop - assert!(input.len() == output.len()); // Assert to convince the compiler to omit bounds checks inside the loop + assert_eq!(input.len(), output.len()); // Assert to convince the compiler to omit bounds checks inside the loop let main_loop_count = input.len() / T::VectorType::COMPLEX_PER_VECTOR; let remainder_count = input.len() % T::VectorType::COMPLEX_PER_VECTOR; diff --git a/src/avx/mod.rs b/src/avx/mod.rs index dc3dbea..4aec9df 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -23,13 +23,57 @@ struct CommonSimdData { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, } macro_rules! boilerplate_avx_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr, $immut_scratch_len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let required_scratch = self.get_immutable_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if 
result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ) + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -53,7 +97,7 @@ macro_rules! boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -88,7 +132,7 @@ macro_rules! boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -111,6 +155,10 @@ macro_rules! boilerplate_avx_fft { fn get_outofplace_scratch_len(&self) -> usize { $out_of_place_scratch_len_fn(self) } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + $immut_scratch_len_fn(self) + } } impl Length for $struct_name { #[inline(always)] @@ -130,6 +178,52 @@ macro_rules! boilerplate_avx_fft { macro_rules! 
boilerplate_avx_fft_commondata { ($struct_name:ident) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.common_data.immut_scratch_len; + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + self.get_immutable_scratch_len(), + scratch.len(), + ); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + self.get_immutable_scratch_len(), + scratch.len(), + ); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -157,7 +251,7 @@ macro_rules! boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -196,7 +290,7 @@ macro_rules! 
boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -219,6 +313,10 @@ macro_rules! boilerplate_avx_fft_commondata { fn get_outofplace_scratch_len(&self) -> usize { self.common_data.outofplace_scratch_len } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + self.common_data.immut_scratch_len + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/common.rs b/src/common.rs index 52eb5e5..1670ebf 100644 --- a/src/common.rs +++ b/src/common.rs @@ -70,9 +70,87 @@ pub fn fft_error_outofplace( ); } +// Prints an error raised by an in-place FFT algorithm's `process_inplace` method +// Marked cold and inline never to keep all formatting code out of the many monomorphized process_inplace methods +#[cold] +#[inline(never)] +pub fn fft_error_immut( + expected_len: usize, + actual_input: usize, + actual_output: usize, + expected_scratch: usize, + actual_scratch: usize, +) { + assert_eq!(actual_input, actual_output, "Provided FFT input buffer and output buffer must have the same length. Got input.len() = {}, output.len() = {}", actual_input, actual_output); + assert!( + actual_input >= expected_len, + "Provided FFT buffer was too small. Expected len = {}, got len = {}", + expected_len, + actual_input + ); + assert_eq!( + actual_input % expected_len, + 0, + "Input FFT buffer must be a multiple of FFT length. Expected multiple of {}, got len = {}", + expected_len, + actual_input + ); + assert!( + actual_scratch >= expected_scratch, + "Not enough scratch space was provided. Expected scratch len >= {}, got scratch len = {}", + expected_scratch, + actual_scratch + ); +} + macro_rules! 
boilerplate_fft_oop { - ($struct_name:ident, $len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $immut_scratch_len:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.get_immutable_scratch_len(); + if input.len() < self.len() + || output.len() != input.len() + || scratch.len() < required_scratch + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -99,7 +177,7 @@ macro_rules! boilerplate_fft_oop { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -132,7 +210,7 @@ macro_rules! 
boilerplate_fft_oop { } let (scratch, extra_scratch) = scratch.split_at_mut(self.len()); - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, extra_scratch); chunk.copy_from_slice(scratch); }); @@ -156,6 +234,10 @@ macro_rules! boilerplate_fft_oop { fn get_outofplace_scratch_len(&self) -> usize { self.outofplace_scratch_len() } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + $immut_scratch_len(self) + } } impl Length for $struct_name { #[inline(always)] @@ -173,8 +255,49 @@ macro_rules! boilerplate_fft_oop { } macro_rules! boilerplate_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr, $immut_scratch_len:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + let required_scratch = self.get_immutable_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + crate::common::fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + if result.is_err() { + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -202,7 +325,7 @@ macro_rules! 
boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -241,7 +364,7 @@ macro_rules! boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -264,6 +387,10 @@ macro_rules! boilerplate_fft { fn get_outofplace_scratch_len(&self) -> usize { $out_of_place_scratch_len_fn(self) } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + $immut_scratch_len(self) + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/lib.rs b/src/lib.rs index a897d42..d88f330 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -234,6 +234,25 @@ pub trait Fft: Length + Direction + Sync + Send { scratch: &mut [Complex], ); + /// Divides `input` and `output` into chunks of `self.len()`, and computes a FFT on each chunk while + /// keeping `input` untouched. + /// + /// This method uses the `scratch` buffer as scratch space, so the contents should be considered garbage after calling. + /// + /// # Panics + /// + /// This method panics if: + /// - `output.len() ! input.len()` + /// - `input.len() % self.len() > 0` + /// - `input.len() < self.len()` + /// - `scratch.len() < get_immutable_scratch_len()` + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ); + /// Returns the size of the scratch buffer required by `process_with_scratch` /// /// For most FFT sizes, this method will return `self.len()`. For a few small sizes it will return 0, and for some special FFT sizes @@ -247,6 +266,14 @@ pub trait Fft: Length + Direction + Sync + Send { /// (Sizes that require the use of Bluestein's Algorithm), this may return a scratch size larger than `self.len()`. 
/// The returned value may change from one version of RustFFT to the next. fn get_outofplace_scratch_len(&self) -> usize; + + /// Returns the size of the scratch buffer required by `process_immutable_with_scratch` + /// + /// For most FFT sizes, this method will return something between self.len() and self.len() * 2. + /// For a few small sizes it will return 0, and for some special FFT sizes + /// (Sizes that require the use of Bluestein's Algorithm), this may return a scratch size larger than `self.len()`. + /// The returned value may change from one version of RustFFT to the next. + fn get_immutable_scratch_len(&self) -> usize; } // Algorithms implemented to use AVX instructions. Only compiled on x86_64, and only compiled if the "avx" feature flag is set. diff --git a/src/neon/mod.rs b/src/neon/mod.rs index ee8039b..20b936c 100644 --- a/src/neon/mod.rs +++ b/src/neon/mod.rs @@ -12,10 +12,6 @@ mod neon_utils; pub mod neon_planner; -pub use self::neon_butterflies::*; -pub use self::neon_prime_butterflies::*; -pub use self::neon_radix4::*; - use std::arch::aarch64::{float32x4_t, float64x2_t}; use crate::FftNum; diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index dc84ffa..571e8da 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -43,7 +43,7 @@ macro_rules! 
boilerplate_fft_neon_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -52,10 +52,9 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { Ok(()) } - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -64,7 +63,7 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -73,10 +72,10 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -101,20 +100,19 @@ macro_rules! 
boilerplate_fft_neon_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - //#[target_feature(enable = "neon")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -130,6 +128,25 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { macro_rules! boilerplate_fft_neon_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), 
input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -172,6 +189,10 @@ macro_rules! boilerplate_fft_neon_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index 826a320..635cf30 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -57,6 +57,37 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! boilerplate_fft_neon_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -74,7 +105,7 @@ macro_rules! 
boilerplate_fft_neon_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +140,7 @@ macro_rules! boilerplate_fft_neon_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +164,10 @@ macro_rules! boilerplate_fft_neon_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -180,7 +215,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +254,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/neon/neon_prime_butterflies.rs b/src/neon/neon_prime_butterflies.rs index 688b5a2..8ef4810 100644 --- a/src/neon/neon_prime_butterflies.rs +++ b/src/neon/neon_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 841a782..2fbfe4e 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -4,7 +4,7 @@ use 
std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; @@ -83,8 +83,7 @@ impl NeonRadix4 { } } - //#[target_feature(enable = "neon")] - unsafe fn perform_fft_out_of_place( + unsafe fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], @@ -126,6 +125,16 @@ impl NeonRadix4 { cross_fft_len *= ROW_COUNT; } } + + //#[target_feature(enable = "neon")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_neon_oop!(NeonRadix4, |this: &NeonRadix4<_, _>| this.len); diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 97c6bcf..7ce8d21 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -4,9 +4,9 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -48,7 +48,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -61,7 +61,7 @@ macro_rules! 
boilerplate_fft_sse_f32_butterfly { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -70,7 +70,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -79,10 +79,10 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -107,7 +107,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly_noparallel { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -116,11 +116,11 @@ macro_rules! 
boilerplate_fft_sse_f32_butterfly_noparallel { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -147,7 +147,7 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -156,11 +156,11 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -176,6 +176,25 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { macro_rules! 
boilerplate_fft_sse_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -218,6 +237,10 @@ macro_rules! boilerplate_fft_sse_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 0c31995..86593ab 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -57,6 +57,37 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! 
boilerplate_fft_sse_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -109,7 +140,7 @@ macro_rules! boilerplate_fft_sse_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +164,10 @@ macro_rules! boilerplate_fft_sse_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -180,7 +215,7 @@ macro_rules! 
boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +254,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/sse/sse_prime_butterflies.rs b/src/sse/sse_prime_butterflies.rs index 2ffaf97..ac9d500 100644 --- a/src/sse/sse_prime_butterflies.rs +++ b/src/sse/sse_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index 6d8e27e..43eb397 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; @@ -93,7 +93,7 @@ impl SseRadix4 { } #[target_feature(enable = "sse4.1")] - unsafe fn perform_fft_out_of_place( + unsafe fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], @@ -135,6 +135,16 @@ impl SseRadix4 { cross_fft_len *= ROW_COUNT; } } + + #[target_feature(enable = "sse4.1")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + 
self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len); diff --git a/src/test_utils.rs b/src/test_utils.rs index 1bbb48b..d859d88 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -173,6 +173,39 @@ pub fn check_fft_algorithm( ); } } + + // test process_immutable_with_scratch() + { + let mut input = reference_input.clone(); + let mut scratch = vec![Zero::zero(); fft.get_immutable_scratch_len()]; + let mut output = vec![Zero::zero(); n * len]; + + fft.process_immutable_with_scratch(&input, &mut output, &mut scratch); + + assert!( + compare_vectors(&expected_output, &output), + "process_immutable_with_scratch() failed, length = {}, direction = {}", + len, + direction + ); + + // make sure this algorithm works correctly with dirty scratch + if scratch.len() > 0 { + for item in scratch.iter_mut() { + *item = dirty_scratch_value; + } + input.copy_from_slice(&reference_input); + + fft.process_immutable_with_scratch(&input, &mut output, &mut scratch); + + assert!( + compare_vectors(&expected_output, &output), + "process_immutable_with_scratch() failed the 'dirty scratch' test, length = {}, direction = {}", + len, + direction + ); + } + } } // A fake FFT algorithm that requests much more scratch than it needs. 
You can use this as an inner FFT to other algorithms to test their scratch-supplying logic @@ -182,10 +215,24 @@ pub struct BigScratchAlgorithm { pub inplace_scratch: usize, pub outofplace_scratch: usize, + pub immut_scratch: usize, pub direction: FftDirection, } impl Fft for BigScratchAlgorithm { + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + scratch: &mut [Complex], + ) { + assert!( + scratch.len() >= self.immut_scratch, + "Not enough immut scratch provided, self={:?}, provided scratch={}", + &self, + scratch.len() + ); + } fn process_with_scratch(&self, _buffer: &mut [Complex], scratch: &mut [Complex]) { assert!( scratch.len() >= self.inplace_scratch, @@ -213,6 +260,9 @@ impl Fft for BigScratchAlgorithm { fn get_outofplace_scratch_len(&self) -> usize { self.outofplace_scratch } + fn get_immutable_scratch_len(&self) -> usize { + self.immut_scratch + } } impl Length for BigScratchAlgorithm { fn len(&self) -> usize { diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 90af2e8..231c434 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -46,7 +46,7 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -59,7 +59,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_f32_butterfly { #[target_feature(enable = "simd128")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -68,7 +68,7 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -77,10 +77,10 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -105,7 +105,7 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -114,11 +114,11 @@ macro_rules! 
boilerplate_fft_wasm_simd_f64_butterfly { #[target_feature(enable = "simd128")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -134,6 +134,25 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { macro_rules! boilerplate_fft_wasm_simd_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -176,6 +195,10 @@ macro_rules! 
boilerplate_fft_wasm_simd_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 787170d..0b46f4c 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -57,6 +57,37 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! boilerplate_fft_wasm_simd_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -74,7 +105,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +140,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +164,10 @@ macro_rules! boilerplate_fft_wasm_simd_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + self.len() + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/wasm_simd/wasm_simd_prime_butterflies.rs b/src/wasm_simd/wasm_simd_prime_butterflies.rs index c56585e..4d46e6b 100644 --- a/src/wasm_simd/wasm_simd_prime_butterflies.rs +++ b/src/wasm_simd/wasm_simd_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index db60d99..7368257 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; @@ -83,7 +83,7 @@ impl WasmSimdRadix4 { } #[target_feature(enable = "simd128")] - unsafe fn perform_fft_out_of_place( + unsafe fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], @@ -125,6 +125,16 @@ impl WasmSimdRadix4 { cross_fft_len *= ROW_COUNT; } 
} + + #[target_feature(enable = "simd128")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_wasm_simd_oop!(WasmSimdRadix4, |this: &WasmSimdRadix4<_, _>| this.len); diff --git a/tests/test_immutable.rs b/tests/test_immutable.rs new file mode 100644 index 0000000..029eca5 --- /dev/null +++ b/tests/test_immutable.rs @@ -0,0 +1,52 @@ +use num_complex::Complex; +use rustfft::{FftNum, FftPlanner}; + +const TEST_MAX: usize = 1001; + +#[test] +fn immutable_f32() { + for i in 0..TEST_MAX { + let input = vec![Complex::new(7.0, 8.0); i]; + + let mut_output = fft_wrapper_mut::(&input); + let immut_output = fft_wrapper_immut::(&input); + + assert_eq!(mut_output, immut_output, "{}", i); + } +} + +#[test] +fn immutable_f64() { + for i in 0..TEST_MAX { + let input = vec![Complex::new(7.0, 8.0); i]; + + let mut_output = fft_wrapper_mut::(&input); + let immut_output = fft_wrapper_immut::(&input); + + assert_eq!(mut_output, immut_output, "{}", i); + } +} + +fn fft_wrapper_mut(input: &[Complex]) -> Vec> { + let cz = Complex::new(T::zero(), T::zero()); + let mut plan = FftPlanner::::new(); + let p = plan.plan_fft_forward(input.len()); + + let mut scratch = vec![cz; p.get_inplace_scratch_len()]; + let mut output = input.to_vec(); + + p.process_with_scratch(&mut output, &mut scratch); + output +} + +fn fft_wrapper_immut(input: &[Complex]) -> Vec> { + let cz = Complex::new(T::zero(), T::zero()); + let mut plan = FftPlanner::::new(); + let p = plan.plan_fft_forward(input.len()); + + let mut scratch = vec![cz; p.get_immutable_scratch_len()]; + let mut output = vec![cz; input.len()]; + + p.process_immutable_with_scratch(input, &mut output, &mut scratch); + output +} diff --git a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs index 9038286..6937f5c 100644 --- 
a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs +++ b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length};