From 026236f9e56c1e56e34d0427f5ee14ab0a43e2aa Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Thu, 1 May 2025 22:37:54 -0500 Subject: [PATCH 01/26] Adding process_immutable_with_scratch method --- benches/bench_rustfft.rs | 7 + benches/bench_rustfft_scalar.rs | 10 + src/algorithm/bluesteins_algorithm.rs | 18 +- src/algorithm/butterflies.rs | 34 ++- src/algorithm/dft.rs | 9 + src/algorithm/good_thomas_algorithm.rs | 96 +++++++- src/algorithm/mixed_radix.rs | 110 ++++++++- src/algorithm/raders_algorithm.rs | 60 ++++- src/algorithm/radix3.rs | 17 +- src/algorithm/radix4.rs | 37 ++++ src/algorithm/radixn.rs | 83 +++++++ src/array_utils.rs | 60 ++++- src/avx/avx32_butterflies.rs | 78 ++++++- src/avx/avx64_butterflies.rs | 80 ++++++- src/avx/avx_bluesteins.rs | 56 +++++ src/avx/avx_mixed_radix.rs | 296 +++++++++++++++++++++++++ src/avx/avx_raders.rs | 68 +++++- src/avx/avx_vector.rs | 2 +- src/avx/mod.rs | 108 ++++++++- src/common.rs | 98 +++++++- src/lib.rs | 26 +++ src/neon/neon_butterflies.rs | 81 ++++++- src/neon/neon_common.rs | 43 +++- src/neon/neon_radix4.rs | 43 ++++ src/sse/sse_butterflies.rs | 47 +++- src/sse/sse_common.rs | 22 +- src/test_utils.rs | 50 +++++ src/wasm_simd/wasm_simd_butterflies.rs | 39 +++- src/wasm_simd/wasm_simd_common.rs | 39 +++- src/wasm_simd/wasm_simd_radix4.rs | 44 ++++ tests/test_immutable.rs | 52 +++++ 31 files changed, 1712 insertions(+), 101 deletions(-) create mode 100644 tests/test_immutable.rs diff --git a/benches/bench_rustfft.rs b/benches/bench_rustfft.rs index 299c72cb..387c5f3e 100644 --- a/benches/bench_rustfft.rs +++ b/benches/bench_rustfft.rs @@ -23,6 +23,13 @@ impl Fft for Noop { fn process_outofplace_with_scratch(&self, _input: &mut [Complex], _output: &mut [Complex], _scratch: &mut [Complex]) {} fn get_inplace_scratch_len(&self) -> usize { self.len } fn get_outofplace_scratch_len(&self) -> usize { 0 } + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + _scratch: &mut 
[Complex], + ) {} + fn get_immutable_scratch_len(&self) -> usize { 0 } } impl Length for Noop { fn len(&self) -> usize { self.len } diff --git a/benches/bench_rustfft_scalar.rs b/benches/bench_rustfft_scalar.rs index b57a5d35..24eca927 100644 --- a/benches/bench_rustfft_scalar.rs +++ b/benches/bench_rustfft_scalar.rs @@ -33,6 +33,16 @@ impl Fft for Noop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + _scratch: &mut [Complex], + ) { + } + + fn get_immutable_scratch_len(&self) -> usize { 0 } } impl Length for Noop { fn len(&self) -> usize { diff --git a/src/algorithm/bluesteins_algorithm.rs b/src/algorithm/bluesteins_algorithm.rs index 9cc3642c..315aa390 100644 --- a/src/algorithm/bluesteins_algorithm.rs +++ b/src/algorithm/bluesteins_algorithm.rs @@ -137,9 +137,10 @@ impl BluesteinsAlgorithm { } } - fn perform_fft_out_of_place( + #[inline] + fn perform_fft_immut( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], ) { @@ -179,6 +180,15 @@ impl BluesteinsAlgorithm { *buffer_entry = inner_entry.conj() * twiddle; } } + + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, scratch); + } } boilerplate_fft!( BluesteinsAlgorithm, @@ -186,7 +196,9 @@ boilerplate_fft!( |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() + this.inner_fft.get_inplace_scratch_len(), // in-place scratch len |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() - + this.inner_fft.get_inplace_scratch_len() // out of place scratch len + + this.inner_fft.get_inplace_scratch_len(), // out of place scratch len + |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len() + + this.inner_fft.get_inplace_scratch_len() // immut scratch len ); #[cfg(test)] diff --git a/src/algorithm/butterflies.rs 
b/src/algorithm/butterflies.rs index 12926189..f6672528 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -17,9 +17,10 @@ macro_rules! boilerplate_fft_butterfly { } } impl Fft for $struct_name { - fn process_outofplace_with_scratch( + #[inline] + fn process_immutable_with_scratch( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { @@ -49,6 +50,14 @@ macro_rules! boilerplate_fft_butterfly { fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); } } + fn process_outofplace_with_scratch( + &self, + input: &mut [Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.process_immutable_with_scratch(input, output, _scratch); + } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us @@ -56,7 +65,7 @@ macro_rules! boilerplate_fft_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| unsafe { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| unsafe { self.perform_fft_butterfly(chunk) }); @@ -74,6 +83,10 @@ macro_rules! 
boilerplate_fft_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -104,13 +117,22 @@ impl Butterfly1 { } } impl Fft for Butterfly1 { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + output.copy_from_slice(input); + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { - output.copy_from_slice(&input); + output.copy_from_slice(input); } fn process_with_scratch(&self, _buffer: &mut [Complex], _scratch: &mut [Complex]) {} @@ -122,6 +144,10 @@ impl Fft for Butterfly1 { fn get_outofplace_scratch_len(&self) -> usize { 0 } + + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for Butterfly1 { fn len(&self) -> usize { diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index e0b700cf..7318d3ae 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -45,6 +45,15 @@ impl Dft { 0 } + fn perform_fft_immut( + &self, + signal: &[Complex], + spectrum: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_out_of_place(signal, spectrum, _scratch); + } + fn perform_fft_out_of_place( &self, signal: &[Complex], diff --git a/src/algorithm/good_thomas_algorithm.rs b/src/algorithm/good_thomas_algorithm.rs index 005ba529..53bc6db8 100644 --- a/src/algorithm/good_thomas_algorithm.rs +++ b/src/algorithm/good_thomas_algorithm.rs @@ -50,6 +50,7 @@ pub struct GoodThomasAlgorithm { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, len: usize, direction: FftDirection, @@ -116,6 +117,13 @@ impl GoodThomasAlgorithm { }, height_outofplace_scratch, ); + let height_fft_immut = height_fft.get_immutable_scratch_len(); + let width_fft_immut = width_fft.get_immutable_scratch_len(); + + let immut_scratch_len = 2 * max( + 
max(height_fft_immut, width_fft_immut), + max(outofplace_scratch_len, inplace_scratch_len), + ); Self { width, @@ -129,6 +137,7 @@ impl GoodThomasAlgorithm { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len, len, direction, @@ -241,6 +250,30 @@ impl GoodThomasAlgorithm { self.reindex_output(scratch, buffer); } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); + // Re-index the input, copying from the input to the output in the process + self.reindex_input(input, output); + + // run FFTs of size `width` + self.width_size_fft.process_with_scratch(output, scratch); + + let scratch = &mut scratch[..output.len()]; + // transpose + transpose::transpose(output, scratch, self.width, self.height); + + // run FFTs of size 'height' + self.height_size_fft.process_with_scratch(scratch, scratch2); + + // Re-index the output, copying from the input to the output in the process + self.reindex_output(scratch, output); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -279,7 +312,8 @@ boilerplate_fft!( GoodThomasAlgorithm, |this: &GoodThomasAlgorithm<_>| this.len, |this: &GoodThomasAlgorithm<_>| this.inplace_scratch_len, - |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len + |this: &GoodThomasAlgorithm<_>| this.outofplace_scratch_len, + |this: &GoodThomasAlgorithm<_>| this.immut_scratch_len ); /// Implementation of the Good-Thomas Algorithm, specialized for smaller input sizes @@ -384,6 +418,38 @@ impl GoodThomasAlgorithmSmall { } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // These asserts are for the unsafe blocks down below. 
we're relying on the optimizer to get rid of this assert + assert_eq!(self.len(), input.len()); + assert_eq!(self.len(), output.len()); + + let (input_map, output_map) = self.input_output_map.split_at(self.len()); + + // copy the input using our reordering mapping + for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) { + *output_element = input[input_index]; + } + + // run FFTs of size `width` + self.width_size_fft.process_with_scratch(output, scratch); + + // transpose + unsafe { array_utils::transpose_small(self.width, self.height, output, scratch) }; + + // run FFTs of size 'height' + self.height_size_fft.process_with_scratch(scratch, output); + + // copy to the output, using our output redordeing mapping + for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) { + output[output_index] = *input_element; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -448,7 +514,8 @@ boilerplate_fft!( GoodThomasAlgorithmSmall, |this: &GoodThomasAlgorithmSmall<_>| this.width * this.height, |this: &GoodThomasAlgorithmSmall<_>| this.len(), - |_| 0 + |_| 0, + |this: &GoodThomasAlgorithmSmall<_>| this.len() ); #[cfg(test)] @@ -532,12 +599,15 @@ mod unit_tests { for &len in &scratch_lengths { for &inplace_scratch in &scratch_lengths { for &outofplace_scratch in &scratch_lengths { - inner_ffts.push(Arc::new(BigScratchAlgorithm { - len, - inplace_scratch, - outofplace_scratch, - direction: FftDirection::Forward, - }) as Arc>); + for &immut_scratch in &scratch_lengths { + inner_ffts.push(Arc::new(BigScratchAlgorithm { + len, + inplace_scratch, + outofplace_scratch, + direction: FftDirection::Forward, + immut_scratch, + }) as Arc>); + } } } } @@ -565,6 +635,16 @@ mod unit_tests { &mut outofplace_output, &mut outofplace_scratch, ); + + let immut_input = vec![Complex::zero(); fft.len()]; + let mut immut_output = vec![Complex::zero(); fft.len()]; + let mut immut_scratch = vec![Complex::zero(); 
fft.get_immutable_scratch_len()]; + + fft.process_immutable_with_scratch( + &immut_input, + &mut immut_output, + &mut immut_scratch, + ); } } } diff --git a/src/algorithm/mixed_radix.rs b/src/algorithm/mixed_radix.rs index 808ff0fa..33f31d01 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -44,6 +44,7 @@ pub struct MixedRadix { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, } @@ -103,6 +104,14 @@ impl MixedRadix { width_outofplace_scratch, ); + let height_fft_immut = height_fft.get_immutable_scratch_len(); + let width_fft_immut = width_fft.get_immutable_scratch_len(); + + let immut_scratch_len = 2 * max( + max(height_fft_immut, width_fft_immut), + max(outofplace_scratch_len, inplace_scratch_len), + ); + Self { twiddles: twiddles.into_boxed_slice(), @@ -114,6 +123,7 @@ impl MixedRadix { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len, direction, } @@ -151,6 +161,47 @@ impl MixedRadix { transpose::transpose(scratch, buffer, self.width, self.height); } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch_raw: &mut [Complex], + ) { + // We require twice as much scratch here as perform_fft_out_of_place + // We have this psuedocode: + // ... + // fft(output, scratch) // FFT in place with scratch + // transpose(output, scratch) // Transpose output -> scratch + // fft(scratch, output) // FFT in place using `output` as scratch + // ... 
+ // process_with_scratch can transpose the output into the input variable, saving scratch for just fft scratch + // Since we can't use the input variable, allocate twice as much scratch as process_with_scratch, + // and split them into two scratches we can use + let (scratch, scratch2) = scratch_raw.split_at_mut(scratch_raw.len() / 2); + // SIX STEP FFT: + + // STEP 1: transpose + transpose::transpose(input, output, self.width, self.height); + + // STEP 2: perform FFTs of size `height` + self.height_size_fft.process_with_scratch(output, scratch); + + // STEP 3: Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + let scratch2 = &mut scratch2[..output.len()]; + // STEP 4: transpose again + transpose::transpose(output, scratch2, self.height, self.width); + + // STEP 5: perform FFTs of size `width` + self.width_size_fft.process_with_scratch(scratch2, scratch); + + // STEP 6: transpose again + transpose::transpose(scratch2, output, self.width, self.height); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -196,7 +247,8 @@ boilerplate_fft!( MixedRadix, |this: &MixedRadix<_>| this.twiddles.len(), |this: &MixedRadix<_>| this.inplace_scratch_len, - |this: &MixedRadix<_>| this.outofplace_scratch_len + |this: &MixedRadix<_>| this.outofplace_scratch_len, + |this: &MixedRadix<_>| this.immut_scratch_len ); /// Implementation of the Mixed-Radix FFT algorithm, specialized for smaller input sizes @@ -302,6 +354,34 @@ impl MixedRadixSmall { unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) }; } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // SIX STEP FFT: + // STEP 1: transpose + unsafe { array_utils::transpose_small(self.width, self.height, input, output) }; + + // STEP 2: perform FFTs of size `height` + self.height_size_fft.process_with_scratch(output, scratch); + + // STEP 3: 
Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + // STEP 4: transpose again + unsafe { array_utils::transpose_small(self.height, self.width, output, scratch) }; + + // STEP 5: perform FFTs of size `width` + self.width_size_fft.process_with_scratch(scratch, output); + + // STEP 6: transpose again + unsafe { array_utils::transpose_small(self.width, self.height, scratch, output) }; + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -334,7 +414,8 @@ boilerplate_fft!( MixedRadixSmall, |this: &MixedRadixSmall<_>| this.twiddles.len(), |this: &MixedRadixSmall<_>| this.len(), - |_| 0 + |_| 0, + |this: &MixedRadixSmall<_>| this.len() ); #[cfg(test)] @@ -393,12 +474,15 @@ mod unit_tests { for &len in &scratch_lengths { for &inplace_scratch in &scratch_lengths { for &outofplace_scratch in &scratch_lengths { - inner_ffts.push(Arc::new(BigScratchAlgorithm { - len, - inplace_scratch, - outofplace_scratch, - direction: FftDirection::Forward, - }) as Arc>); + for &immut_scratch in &scratch_lengths { + inner_ffts.push(Arc::new(BigScratchAlgorithm { + len, + inplace_scratch, + outofplace_scratch, + immut_scratch, + direction: FftDirection::Forward, + }) as Arc>); + } } } } @@ -421,6 +505,16 @@ mod unit_tests { &mut outofplace_output, &mut outofplace_scratch, ); + + let immut_input = vec![Complex::zero(); fft.len()]; + let mut immut_output = vec![Complex::zero(); fft.len()]; + let mut immut_scratch = vec![Complex::zero(); fft.get_immutable_scratch_len()]; + + fft.process_immutable_with_scratch( + &immut_input, + &mut immut_output, + &mut immut_scratch, + ); } } } diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 7f059dd0..5a1198c2 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -50,6 +50,8 @@ pub struct RadersAlgorithm { len: StrengthReducedUsize, inplace_scratch_len: usize, outofplace_scratch_len: 
usize, + immut_scratch_len: usize, + direction: FftDirection, } @@ -100,6 +102,8 @@ impl RadersAlgorithm { } else { required_inner_scratch }; + let inplace_scratch_len = inner_fft_len + extra_inner_scratch; + let immut_scratch_len = inplace_scratch_len * 2; //precompute a FFT of our reordered twiddle factors let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch]; @@ -115,10 +119,63 @@ impl RadersAlgorithm { len: reduced_len, inplace_scratch_len: inner_fft_len + extra_inner_scratch, outofplace_scratch_len: extra_inner_scratch, + immut_scratch_len, direction, } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // The first output element is just the sum of all the input elements, and we need to store off the first input value + let (output_first, output) = output.split_first_mut().unwrap(); + let (input_first, input) = input.split_first().unwrap(); + let (scratch, scratch2) = scratch.split_at_mut(output.len()); + + // copy the input into the output, reordering as we go. also compute a sum of all elements + let mut input_index = 1; + for output_element in scratch.iter_mut() { + input_index = (input_index * self.primitive_root) % self.len; + + let input_element = input[input_index - 1]; + *output_element = input_element; + } + + self.inner_fft.process_with_scratch(scratch, scratch2); + + // output[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input + *output_first = *input_first + scratch[0]; + + // let scratch = &mut scratch[..output.len()]; + + // multiply the inner result with our cached setup data + // also conjugate every entry. 
this sets us up to do an inverse FFT + // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) + for ((output_cell, scratch_cell), &multiple) in scratch + .iter() + .zip(scratch2.iter_mut()) + .zip(self.inner_fft_data.iter()) + { + *scratch_cell = (*output_cell * multiple).conj(); + } + + // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. + // Of course, we have to conjugate it, just like we conjugated the complex multiplied above + scratch2[0] = scratch2[0] + input_first.conj(); + + self.inner_fft.process_with_scratch(scratch2, scratch); + + // copy the final values into the output, reordering as we go + let mut output_index = 1; + for input_element in scratch2 { + output_index = (output_index * self.primitive_root_inverse) % self.len; + output[output_index - 1] = input_element.conj(); + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -232,7 +289,8 @@ boilerplate_fft!( RadersAlgorithm, |this: &RadersAlgorithm<_>| this.len.get(), |this: &RadersAlgorithm<_>| this.inplace_scratch_len, - |this: &RadersAlgorithm<_>| this.outofplace_scratch_len + |this: &RadersAlgorithm<_>| this.outofplace_scratch_len, + |this: &RadersAlgorithm<_>| this.immut_scratch_len ); #[cfg(test)] diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index d392f8c1..411a44c4 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -118,9 +118,10 @@ impl Radix3 { self.outofplace_scratch_len } - fn perform_fft_out_of_place( + #[inline] + fn perform_fft_immut( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], ) { @@ -132,8 +133,7 @@ impl Radix3 { } // Base-level FFTs - let base_scratch = if scratch.len() > 0 { scratch } else { input }; - self.base_fft.process_with_scratch(output, base_scratch); + self.base_fft.process_with_scratch(output, scratch); // cross-FFTs const 
ROW_COUNT: usize = 3; @@ -153,6 +153,15 @@ impl Radix3 { layer_twiddles = &layer_twiddles[twiddle_offset..]; } } + + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, scratch); + } } boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 33a804e4..d9bf8af0 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -124,6 +124,43 @@ impl Radix4 { self.outofplace_scratch_len } + #[inline] + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 4>(self.base_len, input, output); + } + + self.base_fft.process_with_scratch(output, scratch); + + // cross-FFTs + const ROW_COUNT: usize = 4; + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + let butterfly4 = Butterfly4::new(self.direction); + + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index df561811..80189936 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -158,6 +158,89 @@ impl RadixN { self.outofplace_scratch_len } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if let Some(unroll_factor) = 
self.factors.first() { + // for performance, we really, really want to unroll the transpose, but we need to make sure the output length is divisible by the unroll amount + // choosing the first factor seems to reliably perform well + match unroll_factor.factor { + RadixFactor::Factor2 => { + factor_transpose::, 2>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor3 => { + factor_transpose::, 3>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor4 => { + factor_transpose::, 4>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor5 => { + factor_transpose::, 5>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor6 => { + factor_transpose::, 6>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor7 => { + factor_transpose::, 7>(self.base_len, input, output, &self.factors) + } + } + } else { + // no factors, so just pass data straight to our base + output.copy_from_slice(input); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, scratch); + + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + for factor in self.butterflies.iter() { + let cross_fft_columns = cross_fft_len; + cross_fft_len *= factor.radix(); + + match factor { + InternalRadixFactor::Factor2(butterfly2) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_2(data, layer_twiddles, cross_fft_columns, butterfly2) } + } + } + InternalRadixFactor::Factor3(butterfly3) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, cross_fft_columns, butterfly3) } + } + } + InternalRadixFactor::Factor4(butterfly4) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, cross_fft_columns, butterfly4) } + } + } + InternalRadixFactor::Factor5(butterfly5) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_5(data, 
layer_twiddles, cross_fft_columns, butterfly5) } + } + } + InternalRadixFactor::Factor6(butterfly6) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_6(data, layer_twiddles, cross_fft_columns, butterfly6) } + } + } + InternalRadixFactor::Factor7(butterfly7) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_7(data, layer_twiddles, cross_fft_columns, butterfly7) } + } + } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = cross_fft_columns * (factor.radix() - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/array_utils.rs b/src/array_utils.rs index 8f0fe3c2..40589873 100644 --- a/src/array_utils.rs +++ b/src/array_utils.rs @@ -145,7 +145,7 @@ mod unit_tests { // Loop over exact chunks of the provided buffer. Very similar in semantics to ChunksExactMut, but generates smaller code and requires no modulo operations // Returns Ok() if every element ended up in a chunk, Err() if there was a remainder -pub fn iter_chunks( +pub fn iter_chunks_mut( mut buffer: &mut [T], chunk_size: usize, mut chunk_fn: impl FnMut(&mut [T]), @@ -169,20 +169,62 @@ pub fn iter_chunks( // Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations // Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder pub fn iter_chunks_zipped( + mut buffer1: &[T], + mut buffer2: &mut [T], + chunk_size: usize, + mut chunk_fn: impl FnMut(&[T], &mut [T]), +) -> Result<(), ()> { + // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size + let uneven = match buffer1.len().cmp(&buffer2.len()) { + std::cmp::Ordering::Less => { + buffer2 = &mut buffer2[..buffer1.len()]; + true + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + buffer1 = &buffer1[..buffer2.len()]; + true + } + }; + + // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each + while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size { + let (head1, tail1) = buffer1.split_at(chunk_size); + buffer1 = tail1; + + let (head2, tail2) = buffer2.split_at_mut(chunk_size); + buffer2 = tail2; + + chunk_fn(head1, head2); + } + + // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder + if !uneven && buffer1.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations +// Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder +pub fn iter_chunks_zipped_mut( mut buffer1: &mut [T], mut buffer2: &mut [T], chunk_size: usize, mut chunk_fn: impl FnMut(&mut [T], &mut [T]), ) -> Result<(), ()> { // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size - let uneven = if buffer1.len() > buffer2.len() { - buffer1 = &mut buffer1[..buffer2.len()]; - true - } else if buffer2.len() < buffer1.len() { - buffer2 = &mut buffer2[..buffer1.len()]; - true - } else { - false + let uneven = match buffer1.len().cmp(&buffer2.len()) { + std::cmp::Ordering::Less => { + buffer2 = &mut buffer2[..buffer1.len()]; + true + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + buffer1 = &mut buffer1[..buffer2.len()]; + true + } }; // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index 87f9c10a..a42449c1 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -5,8 +5,8 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; +use crate::array_utils::*; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -37,9 +37,9 @@ macro_rules! boilerplate_fft_simd_butterfly { } impl Fft for $struct_name { - fn process_outofplace_with_scratch( + fn process_immutable_with_scratch( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { @@ -56,7 +56,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly { |in_chunk, out_chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_f32(DoubleBuf { input: input_slice, @@ -72,6 +72,14 @@ macro_rules! boilerplate_fft_simd_butterfly { fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); } } + fn process_outofplace_with_scratch( + &self, + input: &mut [Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.process_immutable_with_scratch(input, output, _scratch); + } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us @@ -79,7 +87,7 @@ macro_rules! boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f32(workaround_transmute_mut::<_, Complex>(chunk)); @@ -100,6 +108,10 @@ macro_rules! boilerplate_fft_simd_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -156,10 +168,11 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { } #[inline] - fn perform_fft_out_of_place( + fn perform_fft_immut( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], + _scratch: &mut [Complex], ) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available @@ -169,8 +182,51 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available unsafe { self.row_butterflies(output) }; } + + #[inline] + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + ) { + self.perform_fft_immut(input, output, &mut []); + } } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + unsafe { array_utils::workaround_transmute(input) }; + let transmuted_output: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(output) }; + let transmuted_scratch: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(scratch) }; + let result = array_utils::iter_chunks_zipped( + transmuted_input, + transmuted_output, + self.len(), + |in_chunk, out_chunk| { + 
self.perform_fft_immut(in_chunk, out_chunk, transmuted_scratch) + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +244,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +272,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); @@ -234,6 +290,10 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index c15d5459..6ca02e6c 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -5,8 +5,8 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -35,9 +35,9 @@ macro_rules! boilerplate_fft_simd_butterfly { } } impl Fft for $struct_name { - fn process_outofplace_with_scratch( + fn process_immutable_with_scratch( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { @@ -54,7 +54,7 @@ macro_rules! boilerplate_fft_simd_butterfly { |in_chunk, out_chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_f64(DoubleBuf { input: input_slice, @@ -70,6 +70,14 @@ macro_rules! 
boilerplate_fft_simd_butterfly { fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); } } + fn process_outofplace_with_scratch( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.process_immutable_with_scratch(input, output, scratch); + } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us @@ -77,7 +85,7 @@ macro_rules! boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f64(workaround_transmute_mut::<_, Complex>(chunk)); @@ -98,6 +106,10 @@ macro_rules! boilerplate_fft_simd_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -155,6 +167,22 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { }; } + #[inline] + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + unsafe { self.column_butterflies_and_transpose(input, output) }; + + // process the row FFTs in-place in the output buffer + // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { self.row_butterflies(output) }; + } + #[inline] fn perform_fft_out_of_place( &self, @@ -171,6 +199,40 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { } } impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + unsafe { array_utils::workaround_transmute(input) }; + let transmuted_output: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(output) }; + let transmuted_scratch: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(scratch) }; + let result = array_utils::iter_chunks_zipped( + transmuted_input, + transmuted_output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_immut(in_chunk, 
out_chunk, transmuted_scratch) + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +250,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +278,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); @@ -234,6 +296,10 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index 30e8f8e1..b0cf8b12 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -144,6 +144,7 @@ impl BluesteinsAvx { inplace_scratch_len: required_scratch, outofplace_scratch_len: required_scratch, + immut_scratch_len: required_scratch, direction, }, @@ -311,6 +312,61 @@ impl BluesteinsAvx { } } + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let (inner_input, inner_scratch) = scratch + .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR); + + // do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + self.prepare_bluesteins(transmuted_input, transmuted_inner_input) + } + + // run our inner forward FFT + self.common_data + .inner_fft + .process_with_scratch(inner_input, inner_scratch); + + // Multiply our inner FFT output by our precomputed data. Then, conjugate the result to set up for an inverse FFT. + // We can conjugate the result of multiplication by conjugating both inputs. 
We pre-conjugated the multiplier array, + // so we just need to conjugate inner_input, which the pairwise_complex_multiply_conjugated function will handle + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + Self::pairwise_complex_multiply_conjugated( + transmuted_inner_input, + &self.inner_fft_multiplier, + ) + }; + + // inverse FFT. we're computing a forward but we're massaging it into an inverse by conjugating the inputs and outputs + self.common_data + .inner_fft + .process_with_scratch(inner_input, inner_scratch); + + // finalize the result + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + self.finalize_bluesteins(transmuted_inner_input, transmuted_output) + } + } + fn perform_fft_out_of_place( &self, input: &[Complex], diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index ed6de0fc..1887bdb2 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -69,6 +69,44 @@ macro_rules! 
boilerplate_mixedradix { } } + #[inline] + fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); + let scratch = &mut scratch[..input.len()]; + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch); + + self.perform_column_butterflies_immut(transmuted_input, transmuted_output); + } + + // process the row FFTs. If extra scratch was provided, pass it in. Otherwise, use the output. + self.common_data + .inner_fft + .process_with_scratch(scratch, scratch2); + + // Transpose + // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + + self.transpose(transmuted_input, transmuted_output) + } + } + #[inline] fn perform_fft_out_of_place( &self, @@ -138,11 +176,13 @@ macro_rules! 
mixedradix_gen_data { let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len(); let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len(); + let immut_scratch_len = usize::max($inner_fft.get_immutable_scratch_len() * 2, len * 2); CommonSimdData { twiddles: twiddles.into_boxed_slice(), inplace_scratch_len: len + inner_outofplace_scratch, outofplace_scratch_len: if inner_inplace_scratch > len { inner_inplace_scratch } else { 0 }, + immut_scratch_len, inner_fft: $inner_fft, len, direction, @@ -152,6 +192,143 @@ macro_rules! mixedradix_gen_data { macro_rules! mixedradix_column_butterflies { ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => { + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_column_butterflies_immut( + &self, + input: impl AvxArray, + mut buffer: impl AvxArrayMut, + ) { + // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc + const ROW_COUNT: usize = $row_count; + const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1; + + let len_per_row = self.len() / ROW_COUNT; + let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR; + + // If we don't have any column FFTs, we need to move the + // input into the output buffer for the remainder processing + if chunk_count == 0 { + let simd_ops = self.len() / 4; + for i in (0..simd_ops).step_by(4) { + buffer.store_complex(input.load_complex(i), i); + } + } else { + // process the column FFTs + for (c, twiddle_chunk) in self + .common_data + .twiddles + .chunks_exact(TWIDDLES_PER_COLUMN) + .take(chunk_count) + .enumerate() + { + let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; + + // Load columns from the input into registers + let mut columns = [AvxVector::zero(); ROW_COUNT]; + for i in 0..ROW_COUNT { + columns[i] = input.load_complex(index_base + len_per_row * i); + } + + // apply our butterfly function down the columns + let output = $butterfly_fn(columns, self); + + // always write the first row directly back without twiddles + 
buffer.store_complex(output[0], index_base); + + // for every other row, apply twiddle factors and then write back to memory + for i in 1..ROW_COUNT { + let twiddle = twiddle_chunk[i - 1]; + let output = AvxVector::mul_complex(twiddle, output[i]); + buffer.store_complex(output, index_base + len_per_row * i); + } + } + } + + // finally, we might have a remainder chunk + // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns + let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; + if partial_remainder > 0 { + // We need to copy the partial remainders to the buffer + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + match partial_remainder { + 1 => buffer.store_partial1_complex(input.load_partial1_complex(cs), cs), + 2 => buffer.store_partial2_complex(input.load_partial2_complex(cs), cs), + 3 => buffer.store_partial3_complex(input.load_partial3_complex(cs), cs), + _ => unreachable!(), + } + } + let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; + let partial_remainder_twiddle_base = + self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; + let final_twiddle_chunk = + &self.common_data.twiddles[partial_remainder_twiddle_base..]; + + if partial_remainder > 2 { + // Load 3 columns into full AVX vectors to process our remainder + let mut columns = [AvxVector::zero(); ROW_COUNT]; + for i in 0..ROW_COUNT { + columns[i] = + buffer.load_partial3_complex(partial_remainder_base + len_per_row * i); + } + + // apply our butterfly function down the columns + let mid = $butterfly_fn(columns, self); + + // always write the first row without twiddles + buffer.store_partial3_complex(mid[0], partial_remainder_base); + + // for the remaining rows, apply twiddle factors and then write back to memory + for i in 1..ROW_COUNT { + let twiddle = 
final_twiddle_chunk[i - 1]; + let output = AvxVector::mul_complex(twiddle, mid[i]); + buffer.store_partial3_complex( + output, + partial_remainder_base + len_per_row * i, + ); + } + } else { + // Load 1 or 2 columns into half vectors to process our remainder. Thankfully, the compiler is smart enough to eliminate this branch on f64, since the partial remainder can only possibly be 1 + let mut columns = [AvxVector::zero(); ROW_COUNT]; + if partial_remainder == 1 { + for i in 0..ROW_COUNT { + columns[i] = buffer + .load_partial1_complex(partial_remainder_base + len_per_row * i); + } + } else { + for i in 0..ROW_COUNT { + columns[i] = buffer + .load_partial2_complex(partial_remainder_base + len_per_row * i); + } + } + + // apply our butterfly function down the columns + let mut mid = $butterfly_fn_lo(columns, self); + + // apply twiddle factors + for i in 1..ROW_COUNT { + mid[i] = AvxVector::mul_complex(final_twiddle_chunk[i - 1].lo(), mid[i]); + } + + // store output + if partial_remainder == 1 { + for i in 0..ROW_COUNT { + buffer.store_partial1_complex( + mid[i], + partial_remainder_base + len_per_row * i, + ); + } + } else { + for i in 0..ROW_COUNT { + buffer.store_partial2_complex( + mid[i], + partial_remainder_base + len_per_row * i, + ); + } + } + } + } + } #[target_feature(enable = "avx", enable = "fma")] unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut) { // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc @@ -862,6 +1039,125 @@ impl MixedRadix16xnAvx { } } } + #[target_feature(enable = "avx", enable = "fma")] + unsafe fn perform_column_butterflies_immut( + &self, + input: impl AvxArray, + mut buffer: impl AvxArrayMut, + ) { + // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc + const ROW_COUNT: usize = 16; + const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1; + + let len_per_row = self.len() / ROW_COUNT; + let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR; + + // process the column FFTs + for (c, 
twiddle_chunk) in self + .common_data + .twiddles + .chunks_exact(TWIDDLES_PER_COLUMN) + .take(chunk_count) + .enumerate() + { + let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; + + column_butterfly16_loadfn!( + |index| input.load_complex(index_base + len_per_row * index), + |mut data, index| { + if index > 0 { + data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]); + } + buffer.store_complex(data, index_base + len_per_row * index) + }, + self.twiddles_butterfly16, + self.twiddles_butterfly4 + ); + } + + // finally, we might have a single partial chunk. + // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns + let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; + if partial_remainder > 0 { + // We need to copy the partial remainders to the buffer + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + match partial_remainder { + 1 => buffer.store_partial1_complex(input.load_partial1_complex(cs), cs), + 2 => buffer.store_partial2_complex(input.load_partial2_complex(cs), cs), + 3 => buffer.store_partial3_complex(input.load_partial3_complex(cs), cs), + _ => unreachable!(), + } + } + let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; + let partial_remainder_twiddle_base = + self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; + let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..]; + + match partial_remainder { + 1 => { + column_butterfly16_loadfn!( + |index| buffer + .load_partial1_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + let twiddle: A::VectorType = final_twiddle_chunk[index - 1]; + data = AvxVector::mul_complex(data, twiddle.lo()); + } + buffer.store_partial1_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + [ + 
self.twiddles_butterfly16[0].lo(), + self.twiddles_butterfly16[1].lo() + ], + self.twiddles_butterfly4.lo() + ); + } + 2 => { + column_butterfly16_loadfn!( + |index| buffer + .load_partial2_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + let twiddle: A::VectorType = final_twiddle_chunk[index - 1]; + data = AvxVector::mul_complex(data, twiddle.lo()); + } + buffer.store_partial2_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + [ + self.twiddles_butterfly16[0].lo(), + self.twiddles_butterfly16[1].lo() + ], + self.twiddles_butterfly4.lo() + ); + } + 3 => { + column_butterfly16_loadfn!( + |index| buffer + .load_partial3_complex(partial_remainder_base + len_per_row * index), + |mut data, index| { + if index > 0 { + data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]); + } + buffer.store_partial3_complex( + data, + partial_remainder_base + len_per_row * index, + ) + }, + self.twiddles_butterfly16, + self.twiddles_butterfly4 + ); + } + _ => unreachable!(), + } + } + } mixedradix_transpose!(16, AvxVector::transpose16_packed, AvxVector::transpose16_packed, diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index 34b76b1d..b3156c5e 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -107,6 +107,7 @@ pub struct RadersAvx2 { inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, _phantom: std::marker::PhantomData, @@ -267,6 +268,7 @@ impl RadersAvx2 { }) .collect::>() }; + let inplace_scratch_len = len + extra_inner_scratch; Self { input_index_multiplier, input_index_init, @@ -278,8 +280,9 @@ impl RadersAvx2 { len, - inplace_scratch_len: len + extra_inner_scratch, + inplace_scratch_len, outofplace_scratch_len: extra_inner_scratch, + immut_scratch_len: inplace_scratch_len * 2, direction, _phantom: std::marker::PhantomData, @@ -351,6 +354,66 @@ impl RadersAvx2 { } } + fn perform_fft_immut( + &self, + input: 
&[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.prepare_raders(transmuted_input, transmuted_output) + } + + let (first_input, _) = input.split_first().unwrap(); + let (first_output, inner_output) = output.split_first_mut().unwrap(); + + let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); + let (_, scratch_inner) = scratch2.split_first_mut().unwrap(); + + // perform the first of two inner FFTs + self.inner_fft.process_with_scratch(inner_output, scratch); + + // inner_output[0] now contains the sum of elements 1..n. we want the sum of all inputs, so all we need to do is add the first input + *first_output = inner_output[0] + *first_input; + + // multiply the inner result with our cached setup data + // also conjugate every entry. this sets us up to do an inverse FFT + // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch_inner); + let transmuted_inner_output: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_output); + avx_vector::pairwise_complex_mul_conjugated( + transmuted_inner_output, + transmuted_inner_input, + &self.twiddles, + ) + }; + + // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. 
+ // Of course, we have to conjugate it, just like we conjugated the complex multiplied above + scratch_inner[0] = scratch_inner[0] + first_input.conj(); + + // execute the second FFT + self.inner_fft.process_with_scratch(scratch_inner, scratch); + + // copy the final values into the output, reordering as we go + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = + array_utils::workaround_transmute_mut(scratch2); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.finalize_raders(transmuted_input, transmuted_output); + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -479,7 +542,8 @@ boilerplate_avx_fft!( RadersAvx2, |this: &RadersAvx2<_, _>| this.len, |this: &RadersAvx2<_, _>| this.inplace_scratch_len, - |this: &RadersAvx2<_, _>| this.outofplace_scratch_len + |this: &RadersAvx2<_, _>| this.outofplace_scratch_len, + |this: &RadersAvx2<_, _>| this.immut_scratch_len ); #[cfg(test)] diff --git a/src/avx/avx_vector.rs b/src/avx/avx_vector.rs index c5f4c945..97807f2f 100644 --- a/src/avx/avx_vector.rs +++ b/src/avx/avx_vector.rs @@ -2265,7 +2265,7 @@ pub unsafe fn pairwise_complex_mul_conjugated( multiplier.len(), input.len() ); // Assert to convince the compiler to omit bounds checks inside the loop - assert!(input.len() == output.len()); // Assert to convince the compiler to omit bounds checks inside the loop + assert_eq!(input.len(), output.len()); // Assert to convince the compiler to omit bounds checks inside the loop let main_loop_count = input.len() / T::VectorType::COMPLEX_PER_VECTOR; let remainder_count = input.len() % T::VectorType::COMPLEX_PER_VECTOR; diff --git a/src/avx/mod.rs b/src/avx/mod.rs index dc3dbeae..6539daca 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -23,13 +23,57 @@ struct CommonSimdData { inplace_scratch_len: usize, 
outofplace_scratch_len: usize, + immut_scratch_len: usize, direction: FftDirection, } macro_rules! boilerplate_avx_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr, $immut_scratch_len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let required_scratch = self.get_immutable_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ) + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -53,7 +97,7 @@ macro_rules! 
boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -88,7 +132,7 @@ macro_rules! boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -111,6 +155,10 @@ macro_rules! boilerplate_avx_fft { fn get_outofplace_scratch_len(&self) -> usize { $out_of_place_scratch_len_fn(self) } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + $immut_scratch_len_fn(self) + } } impl Length for $struct_name { #[inline(always)] @@ -130,6 +178,52 @@ macro_rules! boilerplate_avx_fft { macro_rules! boilerplate_avx_fft_commondata { ($struct_name:ident) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.common_data.immut_scratch_len; + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_immutable_scratch_len(), + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly 
divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -157,7 +251,7 @@ macro_rules! boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -196,7 +290,7 @@ macro_rules! boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -219,6 +313,10 @@ macro_rules! boilerplate_avx_fft_commondata { fn get_outofplace_scratch_len(&self) -> usize { self.common_data.outofplace_scratch_len } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + self.common_data.immut_scratch_len + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/common.rs b/src/common.rs index 52eb5e5b..18725de7 100644 --- a/src/common.rs +++ b/src/common.rs @@ -73,6 +73,45 @@ pub fn fft_error_outofplace( macro_rules! 
boilerplate_fft_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.get_immutable_scratch_len(); + if input.len() < self.len() + || output.len() != input.len() + || scratch.len() < required_scratch + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -99,7 +138,7 @@ macro_rules! boilerplate_fft_oop { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -132,7 +171,7 @@ macro_rules! 
boilerplate_fft_oop { } let (scratch, extra_scratch) = scratch.split_at_mut(self.len()); - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, extra_scratch); chunk.copy_from_slice(scratch); }); @@ -156,6 +195,10 @@ macro_rules! boilerplate_fft_oop { fn get_outofplace_scratch_len(&self) -> usize { self.outofplace_scratch_len() } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + self.inplace_scratch_len() + } } impl Length for $struct_name { #[inline(always)] @@ -173,8 +216,49 @@ macro_rules! boilerplate_fft_oop { } macro_rules! boilerplate_fft { - ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr, $immut_scratch_len:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + let required_scratch = self.get_immutable_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, scratch), + ); + if result.is_err() { + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -202,7 +286,7 @@ macro_rules! 
boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -241,7 +325,7 @@ macro_rules! boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -264,6 +348,10 @@ macro_rules! boilerplate_fft { fn get_outofplace_scratch_len(&self) -> usize { $out_of_place_scratch_len_fn(self) } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + $immut_scratch_len(self) + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/lib.rs b/src/lib.rs index a897d42d..f304c15d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -234,6 +234,25 @@ pub trait Fft: Length + Direction + Sync + Send { scratch: &mut [Complex], ); + /// Divides `input` and `output` into chunks of `self.len()`, and computes a FFT on each chunk while + /// keeping `input` untouched. + /// + /// This method uses the `scratch` buffer as scratch space, so the contents should be considered garbage after calling. + /// + /// # Panics + /// + /// This method panics if: + /// - `output.len() != input.len()` + /// - `input.len() % self.len() > 0` + /// - `input.len() < self.len()` + /// - `scratch.len() < get_immutable_scratch_len()` + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ); + /// Returns the size of the scratch buffer required by `process_with_scratch` /// /// For most FFT sizes, this method will return `self.len()`. For a few small sizes it will return 0, and for some special FFT sizes @@ -247,6 +266,13 @@ pub trait Fft: Length + Direction + Sync + Send { /// (Sizes that require the use of Bluestein's Algorithm), this may return a scratch size larger than `self.len()`.
/// The returned value may change from one version of RustFFT to the next. fn get_outofplace_scratch_len(&self) -> usize; + + /// Returns the size of the scratch buffer required by `process_immutable_with_scratch` + /// + /// For most FFT sizes, this method will return `self.len()`. For a few small sizes it will return 0, and for some special FFT sizes + /// (Sizes that require the use of Bluestein's Algorithm), this may return a scratch size larger than `self.len()`. + /// The returned value may change from one version of RustFFT to the next. + fn get_immutable_scratch_len(&self) -> usize; } // Algorithms implemented to use AVX instructions. Only compiled on x86_64, and only compiled if the "avx" feature flag is set. diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index dc84ffa3..9070ad3b 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -43,7 +43,7 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -59,7 +59,7 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( + let alldone = array_utils::iter_chunks_zipped_mut( input, output, 2 * self.len(), @@ -82,6 +82,36 @@ macro_rules! 
boilerplate_fft_neon_f32_butterfly { } Ok(()) } + + pub(crate) unsafe fn perform_oop_fft_butterfly_multi_immut( + &self, + input: &[Complex], + output: &mut [Complex], + ) -> Result<(), ()> { + let len = input.len(); + let alldone = array_utils::iter_chunks_zipped( + input, + output, + 2 * self.len(), + |in_chunk, out_chunk| { + let input_slice = crate::array_utils::workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_parallel_fft_contiguous(DoubleBuf { + input: input_slice, + output: output_slice, + }) + }, + ); + if alldone.is_err() && input.len() >= self.len() { + let input_slice = crate::array_utils::workaround_transmute(input); + let output_slice = workaround_transmute_mut(output); + self.perform_fft_contiguous(DoubleBuf { + input: &input_slice[len - self.len()..], + output: &mut output_slice[len - self.len()..], + }) + } + Ok(()) + } } }; } @@ -101,7 +131,7 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -112,9 +142,29 @@ macro_rules! 
boilerplate_fft_neon_f64_butterfly { &self, input: &mut [Complex], output: &mut [Complex], + ) -> Result<(), ()> { + array_utils::iter_chunks_zipped_mut( + input, + output, + self.len(), + |in_chunk, out_chunk| { + let input_slice = workaround_transmute_mut(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_contiguous(DoubleBuf { + input: input_slice, + output: output_slice, + }) + }, + ) + } + + pub(crate) unsafe fn perform_oop_fft_butterfly_multi_immut( + &self, + input: &[Complex], + output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -130,6 +180,25 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { macro_rules! boilerplate_fft_neon_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi_immut(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for 
us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -172,6 +241,10 @@ macro_rules! boilerplate_fft_neon_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index 826a320e..28063719 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -57,6 +57,37 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! boilerplate_fft_neon_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -74,7 +105,7 @@ macro_rules! 
boilerplate_fft_neon_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +140,7 @@ macro_rules! boilerplate_fft_neon_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +164,10 @@ macro_rules! boilerplate_fft_neon_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -180,7 +215,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +254,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 841a782c..f52f6b28 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -83,6 +83,49 @@ impl NeonRadix4 { } } + unsafe fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 4>(self.base_len, input, output); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, &mut []); + + // cross-FFTs + const ROW_COUNT: usize = 4; + let mut cross_fft_len = self.base_len * ROW_COUNT; + let mut layer_twiddles: &[N::VectorType] = 
&self.twiddles; + + while cross_fft_len <= input.len() { + let num_rows = input.len() / cross_fft_len; + let num_scalar_columns = cross_fft_len / ROW_COUNT; + let num_vector_columns = num_scalar_columns / N::VectorType::COMPLEX_PER_VECTOR; + + for i in 0..num_rows { + butterfly_4::( + &mut output[i * cross_fft_len..], + layer_twiddles, + num_scalar_columns, + &self.rotation, + ) + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_vector_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + + cross_fft_len *= ROW_COUNT; + } + } + //#[target_feature(enable = "neon")] unsafe fn perform_fft_out_of_place( &self, diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 97c6bcfc..309400b8 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -4,8 +4,8 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -48,7 +48,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -61,7 +61,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -70,7 +70,7 @@ macro_rules! 
boilerplate_fft_sse_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -79,10 +79,10 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -107,7 +107,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly_noparallel { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -116,11 +116,11 @@ macro_rules! boilerplate_fft_sse_f32_butterfly_noparallel { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -147,7 +147,7 @@ macro_rules! 
boilerplate_fft_sse_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -156,11 +156,11 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { #[target_feature(enable = "sse4.1")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -176,6 +176,25 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { macro_rules! boilerplate_fft_sse_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } 
+ } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -218,6 +237,10 @@ macro_rules! boilerplate_fft_sse_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 0c319955..81c13b7d 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -57,9 +57,9 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! boilerplate_fft_sse_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { - fn process_outofplace_with_scratch( + fn process_immutable_with_scratch( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], ) { @@ -90,6 +90,14 @@ macro_rules! boilerplate_fft_sse_oop { fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); } } + fn process_outofplace_with_scratch( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.process_immutable_with_scratch(input, output, scratch); + } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { if self.len() == 0 { return; @@ -109,7 +117,7 @@ macro_rules! boilerplate_fft_sse_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +141,10 @@ macro_rules! boilerplate_fft_sse_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] @@ -180,7 +192,7 @@ macro_rules! 
boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +231,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/test_utils.rs b/src/test_utils.rs index 1bbb48bd..d859d884 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -173,6 +173,39 @@ pub fn check_fft_algorithm( ); } } + + // test process_immutable_with_scratch() + { + let mut input = reference_input.clone(); + let mut scratch = vec![Zero::zero(); fft.get_immutable_scratch_len()]; + let mut output = vec![Zero::zero(); n * len]; + + fft.process_immutable_with_scratch(&input, &mut output, &mut scratch); + + assert!( + compare_vectors(&expected_output, &output), + "process_immutable_with_scratch() failed, length = {}, direction = {}", + len, + direction + ); + + // make sure this algorithm works correctly with dirty scratch + if scratch.len() > 0 { + for item in scratch.iter_mut() { + *item = dirty_scratch_value; + } + input.copy_from_slice(&reference_input); + + fft.process_immutable_with_scratch(&input, &mut output, &mut scratch); + + assert!( + compare_vectors(&expected_output, &output), + "process_immutable_with_scratch() failed the 'dirty scratch' test, length = {}, direction = {}", + len, + direction + ); + } + } } // A fake FFT algorithm that requests much more scratch than it needs. 
You can use this as an inner FFT to other algorithms to test their scratch-supplying logic @@ -182,10 +215,24 @@ pub struct BigScratchAlgorithm { pub inplace_scratch: usize, pub outofplace_scratch: usize, + pub immut_scratch: usize, pub direction: FftDirection, } impl Fft for BigScratchAlgorithm { + fn process_immutable_with_scratch( + &self, + _input: &[Complex], + _output: &mut [Complex], + scratch: &mut [Complex], + ) { + assert!( + scratch.len() >= self.immut_scratch, + "Not enough immut scratch provided, self={:?}, provided scratch={}", + &self, + scratch.len() + ); + } fn process_with_scratch(&self, _buffer: &mut [Complex], scratch: &mut [Complex]) { assert!( scratch.len() >= self.inplace_scratch, @@ -213,6 +260,9 @@ impl Fft for BigScratchAlgorithm { fn get_outofplace_scratch_len(&self) -> usize { self.outofplace_scratch } + fn get_immutable_scratch_len(&self) -> usize { + self.immut_scratch + } } impl Length for BigScratchAlgorithm { fn len(&self) -> usize { diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 90af2e8f..e60aec95 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -46,7 +46,7 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -59,7 +59,7 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { #[target_feature(enable = "simd128")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -68,7 +68,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -77,10 +77,10 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -105,7 +105,7 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -114,11 +114,11 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { #[target_feature(enable = "simd128")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -134,6 +134,25 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { macro_rules! 
boilerplate_fft_wasm_simd_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -176,6 +195,10 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 787170d9..bcab6ad7 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -57,6 +57,37 @@ macro_rules! separate_interleaved_complex_f32 { macro_rules! 
boilerplate_fft_wasm_simd_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_immutable_with_scratch( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -74,7 +105,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +140,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -133,6 +164,10 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { fn get_outofplace_scratch_len(&self) -> usize { 0 } + #[inline(always)] + fn get_immutable_scratch_len(&self) -> usize { + self.len() + } } impl Length for $struct_name { #[inline(always)] diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index db60d993..f4e17238 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -82,6 +82,50 @@ impl WasmSimdRadix4 { } } + #[target_feature(enable = "simd128")] + unsafe fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 4>(self.base_len, input, output); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, &mut []); + + // cross-FFTs + const ROW_COUNT: usize = 4; + let mut cross_fft_len = self.base_len * ROW_COUNT; + let mut layer_twiddles: &[S::VectorType] = &self.twiddles; + + while cross_fft_len <= input.len() { + let num_rows = input.len() / cross_fft_len; + let num_scalar_columns = cross_fft_len / ROW_COUNT; + let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR; + + for i in 0..num_rows { + butterfly_4::( + &mut output[i * cross_fft_len..], + layer_twiddles, + num_scalar_columns, + &self.rotation, + ) + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_vector_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + + cross_fft_len *= ROW_COUNT; + } + } + #[target_feature(enable = "simd128")] unsafe fn perform_fft_out_of_place( &self, diff --git a/tests/test_immutable.rs b/tests/test_immutable.rs new file mode 100644 index 00000000..029eca51 --- /dev/null +++ b/tests/test_immutable.rs @@ -0,0 +1,52 @@ +use num_complex::Complex; +use rustfft::{FftNum, FftPlanner}; + +const TEST_MAX: usize = 1001; + +#[test] +fn 
immutable_f32() { + for i in 0..TEST_MAX { + let input = vec![Complex::new(7.0, 8.0); i]; + + let mut_output = fft_wrapper_mut::(&input); + let immut_output = fft_wrapper_immut::(&input); + + assert_eq!(mut_output, immut_output, "{}", i); + } +} + +#[test] +fn immutable_f64() { + for i in 0..TEST_MAX { + let input = vec![Complex::new(7.0, 8.0); i]; + + let mut_output = fft_wrapper_mut::(&input); + let immut_output = fft_wrapper_immut::(&input); + + assert_eq!(mut_output, immut_output, "{}", i); + } +} + +fn fft_wrapper_mut(input: &[Complex]) -> Vec> { + let cz = Complex::new(T::zero(), T::zero()); + let mut plan = FftPlanner::::new(); + let p = plan.plan_fft_forward(input.len()); + + let mut scratch = vec![cz; p.get_inplace_scratch_len()]; + let mut output = input.to_vec(); + + p.process_with_scratch(&mut output, &mut scratch); + output +} + +fn fft_wrapper_immut(input: &[Complex]) -> Vec> { + let cz = Complex::new(T::zero(), T::zero()); + let mut plan = FftPlanner::::new(); + let p = plan.plan_fft_forward(input.len()); + + let mut scratch = vec![cz; p.get_immutable_scratch_len()]; + let mut output = vec![cz; input.len()]; + + p.process_immutable_with_scratch(input, &mut output, &mut scratch); + output +} From ce3c9814d43b7537db35199e7aea4610ed2a5e7b Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Thu, 15 May 2025 22:46:23 -0500 Subject: [PATCH 02/26] Starting to addressing pull request comments --- src/algorithm/butterflies.rs | 6 +-- src/algorithm/dft.rs | 18 ++++---- src/algorithm/good_thomas_algorithm.rs | 17 ++++--- src/algorithm/mixed_radix.rs | 33 ++++---------- src/algorithm/raders_algorithm.rs | 23 ++++------ src/algorithm/radix3.rs | 20 ++++----- src/avx/avx32_butterflies.rs | 12 ++--- src/avx/avx64_butterflies.rs | 20 +++------ src/avx/avx_bluesteins.rs | 48 +------------------- src/avx/avx_mixed_radix.rs | 62 +++++++++++--------------- src/avx/mod.rs | 8 ++-- src/common.rs | 39 ++++++++++++++-- src/neon/neon_butterflies.rs | 6 +-- 
src/neon/neon_common.rs | 4 +- src/sse/sse_butterflies.rs | 4 +- src/sse/sse_common.rs | 4 +- src/wasm_simd/wasm_simd_butterflies.rs | 4 +- src/wasm_simd/wasm_simd_common.rs | 4 +- 18 files changed, 141 insertions(+), 191 deletions(-) diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index f6672528..4371f9eb 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -3,7 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; use crate::array_utils::{self, DoubleBuf, LoadStore}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_inplace, fft_error_outofplace, fft_error_immut}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -26,8 +26,8 @@ macro_rules! boilerplate_fft_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = array_utils::iter_chunks_zipped( diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index 7318d3ae..c0e1b84a 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -50,15 +50,6 @@ impl Dft { signal: &[Complex], spectrum: &mut [Complex], _scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(signal, spectrum, _scratch); - } - - fn perform_fft_out_of_place( - &self, - signal: &[Complex], - spectrum: &mut [Complex], - _scratch: &mut [Complex], ) { for k in 0..spectrum.len() { let output_cell = spectrum.get_mut(k).unwrap(); @@ -77,6 +68,15 @@ impl Dft { } } } + + fn 
perform_fft_out_of_place( + &self, + signal: &[Complex], + spectrum: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(signal, spectrum, _scratch); + } } boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len()); diff --git a/src/algorithm/good_thomas_algorithm.rs b/src/algorithm/good_thomas_algorithm.rs index 53bc6db8..db4b591f 100644 --- a/src/algorithm/good_thomas_algorithm.rs +++ b/src/algorithm/good_thomas_algorithm.rs @@ -117,12 +117,10 @@ impl GoodThomasAlgorithm { }, height_outofplace_scratch, ); - let height_fft_immut = height_fft.get_immutable_scratch_len(); - let width_fft_immut = width_fft.get_immutable_scratch_len(); - let immut_scratch_len = 2 * max( - max(height_fft_immut, width_fft_immut), - max(outofplace_scratch_len, inplace_scratch_len), + let immut_scratch_len = max( + width_fft.get_inplace_scratch_len(), + len + height_fft.get_inplace_scratch_len(), ); Self { @@ -256,19 +254,20 @@ impl GoodThomasAlgorithm { output: &mut [Complex], scratch: &mut [Complex], ) { - let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); // Re-index the input, copying from the input to the output in the process self.reindex_input(input, output); // run FFTs of size `width` self.width_size_fft.process_with_scratch(output, scratch); - let scratch = &mut scratch[..output.len()]; + let (scratch, inner_scratch) = scratch.split_at_mut(self.len()); + // transpose transpose::transpose(output, scratch, self.width, self.height); // run FFTs of size 'height' - self.height_size_fft.process_with_scratch(scratch, scratch2); + self.height_size_fft + .process_with_scratch(scratch, inner_scratch); // Re-index the output, copying from the input to the output in the process self.reindex_output(scratch, output); @@ -604,8 +603,8 @@ mod unit_tests { len, inplace_scratch, outofplace_scratch, - direction: FftDirection::Forward, immut_scratch, + direction: FftDirection::Forward, }) as Arc>); } } diff --git a/src/algorithm/mixed_radix.rs 
b/src/algorithm/mixed_radix.rs index 33f31d01..63cc9030 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -104,12 +104,9 @@ impl MixedRadix { width_outofplace_scratch, ); - let height_fft_immut = height_fft.get_immutable_scratch_len(); - let width_fft_immut = width_fft.get_immutable_scratch_len(); - - let immut_scratch_len = 2 * max( - max(height_fft_immut, width_fft_immut), - max(outofplace_scratch_len, inplace_scratch_len), + let immut_scratch_len = max( + len + width_fft.get_inplace_scratch_len(), + height_fft.get_inplace_scratch_len(), ); Self { @@ -167,39 +164,27 @@ impl MixedRadix { output: &mut [Complex], scratch_raw: &mut [Complex], ) { - // We require twice as much scratch here as perform_fft_out_of_place - // We have this psuedocode: - // ... - // fft(output, scratch) // FFT in place with scratch - // transpose(output, scratch) // Transpose output -> scratch - // fft(scratch, output) // FFT in place using `output` as scratch - // ... - // process_with_scratch can transpose the output into the input variable, saving scratch for just fft scratch - // Since we can't use the input variable, allocate twice as much scratch as process_with_scratch, - // and split them into two scratches we can use - let (scratch, scratch2) = scratch_raw.split_at_mut(scratch_raw.len() / 2); - // SIX STEP FFT: - // STEP 1: transpose transpose::transpose(input, output, self.width, self.height); // STEP 2: perform FFTs of size `height` - self.height_size_fft.process_with_scratch(output, scratch); + self.height_size_fft.process_with_scratch(output, scratch_raw); // STEP 3: Apply twiddle factors for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { *element = *element * twiddle; } - let scratch2 = &mut scratch2[..output.len()]; + let (scratch, inner_scratch) = scratch_raw.split_at_mut(self.len()); + // STEP 4: transpose again - transpose::transpose(output, scratch2, self.height, self.width); + transpose::transpose(output, scratch, 
self.height, self.width); // STEP 5: perform FFTs of size `width` - self.width_size_fft.process_with_scratch(scratch2, scratch); + self.width_size_fft.process_with_scratch(scratch, inner_scratch); // STEP 6: transpose again - transpose::transpose(scratch2, output, self.width, self.height); + transpose::transpose(scratch, output, self.width, self.height); } fn perform_fft_out_of_place( diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 5a1198c2..ee538e11 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -133,7 +133,7 @@ impl RadersAlgorithm { // The first output element is just the sum of all the input elements, and we need to store off the first input value let (output_first, output) = output.split_first_mut().unwrap(); let (input_first, input) = input.split_first().unwrap(); - let (scratch, scratch2) = scratch.split_at_mut(output.len()); + let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1); // copy the input into the output, reordering as we go. also compute a sum of all elements let mut input_index = 1; @@ -144,35 +144,30 @@ impl RadersAlgorithm { *output_element = input_element; } - self.inner_fft.process_with_scratch(scratch, scratch2); + self.inner_fft.process_with_scratch(scratch, extra_scratch); // output[0] now contains the sum of elements 1..len. We need the sum of all elements, so all we have to do is add the first input *output_first = *input_first + scratch[0]; - // let scratch = &mut scratch[..output.len()]; - // multiply the inner result with our cached setup data // also conjugate every entry. 
this sets us up to do an inverse FFT // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) - for ((output_cell, scratch_cell), &multiple) in scratch - .iter() - .zip(scratch2.iter_mut()) - .zip(self.inner_fft_data.iter()) - { - *scratch_cell = (*output_cell * multiple).conj(); + for (scratch_cell, &twiddle) in scratch.iter_mut().zip(self.inner_fft_data.iter()) { + *scratch_cell = (*scratch_cell * twiddle).conj(); } // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. // Of course, we have to conjugate it, just like we conjugated the complex multiplied above - scratch2[0] = scratch2[0] + input_first.conj(); + scratch[0] = scratch[0] + input_first.conj(); - self.inner_fft.process_with_scratch(scratch2, scratch); + // execute the second FFT + self.inner_fft.process_with_scratch(scratch, extra_scratch); // copy the final values into the output, reordering as we go let mut output_index = 1; - for input_element in scratch2 { + for scratch_element in scratch { output_index = (output_index * self.primitive_root_inverse) % self.len; - output[output_index - 1] = input_element.conj(); + output[output_index - 1] = scratch_element.conj(); } } diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index 411a44c4..c04a9a09 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -118,12 +118,21 @@ impl Radix3 { self.outofplace_scratch_len } - #[inline] fn perform_fft_immut( &self, input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], + ) { + self.perform_fft_out_of_place(input, output, scratch); + } + + #[inline] + fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -153,15 +162,6 @@ impl Radix3 { layer_twiddles = &layer_twiddles[twiddle_offset..]; } } - - fn 
perform_fft_out_of_place( - &self, - input: &mut [Complex], - output: &mut [Complex], - scratch: &mut [Complex], - ) { - self.perform_fft_immut(input, output, scratch); - } } boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index a42449c1..eaaa3b5f 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -6,8 +6,8 @@ use num_complex::Complex; use crate::array_utils; use crate::array_utils::DoubleBuf; -use crate::array_utils::*; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -45,8 +45,8 @@ macro_rules! boilerplate_fft_simd_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = array_utils::iter_chunks_zipped( @@ -201,8 +201,8 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index 6ca02e6c..31479209 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -7,7 +7,7 @@ use num_complex::Complex; use crate::array_utils; use crate::array_utils::DoubleBuf; use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_inplace, fft_error_outofplace, fft_error_immut}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; @@ -43,8 +43,8 @@ macro_rules! 
boilerplate_fft_simd_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = array_utils::iter_chunks_zipped( @@ -174,19 +174,13 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { output: &mut [Complex], _scratch: &mut [Complex], ) { - // Perform the column FFTs - // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available - unsafe { self.column_butterflies_and_transpose(input, output) }; - - // process the row FFTs in-place in the output buffer - // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available - unsafe { self.row_butterflies(output) }; + self.perform_fft_out_of_place(input, output); } #[inline] fn perform_fft_out_of_place( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) { // Perform the column FFTs @@ -207,8 +201,8 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index b0cf8b12..f04fa1cb 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -318,53 +318,7 @@ impl BluesteinsAvx { output: &mut [Complex], scratch: &mut [Complex], ) { - let (inner_input, inner_scratch) = scratch - .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR); - - // do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); - let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(inner_input); - - self.prepare_bluesteins(transmuted_input, transmuted_inner_input) - } - - // run our inner forward FFT - self.common_data - .inner_fft - .process_with_scratch(inner_input, inner_scratch); - - // Multiply our inner FFT output by our precomputed data. Then, conjugate the result to set up for an inverse FFT. 
- // We can conjugate the result of multiplication by conjugating both inputs. We pre-conjugated the multiplier array, - // so we just need to conjugate inner_input, which the pairwise_complex_multiply_conjugated function will handle - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(inner_input); - - Self::pairwise_complex_multiply_conjugated( - transmuted_inner_input, - &self.inner_fft_multiplier, - ) - }; - - // inverse FFT. we're computing a forward but we're massaging it into an inverse by conjugating the inputs and outputs - self.common_data - .inner_fft - .process_with_scratch(inner_input, inner_scratch); - - // finalize the result - unsafe { - // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_output: &mut [Complex] = - array_utils::workaround_transmute_mut(output); - let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(inner_input); - - self.finalize_bluesteins(transmuted_inner_input, transmuted_output) - } + self.perform_fft_out_of_place(input, output, scratch); } fn perform_fft_out_of_place( diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index 1887bdb2..a7f2a386 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -78,8 +78,7 @@ macro_rules! 
boilerplate_mixedradix { ) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available - let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); - let scratch = &mut scratch[..input.len()]; + let (scratch, inner_scratch) = scratch.split_at_mut(input.len()); unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); @@ -92,7 +91,7 @@ macro_rules! boilerplate_mixedradix { // process the row FFTs. If extra scratch was provided, pass it in. Otherwise, use the output. self.common_data .inner_fft - .process_with_scratch(scratch, scratch2); + .process_with_scratch(scratch, inner_scratch); // Transpose // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available @@ -176,7 +175,7 @@ macro_rules! mixedradix_gen_data { let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len(); let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len(); - let immut_scratch_len = usize::max($inner_fft.get_immutable_scratch_len() * 2, len * 2); + let immut_scratch_len = len + $inner_fft.get_inplace_scratch_len(); CommonSimdData { twiddles: twiddles.into_boxed_slice(), @@ -205,42 +204,33 @@ macro_rules! 
mixedradix_column_butterflies { let len_per_row = self.len() / ROW_COUNT; let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR; - // If we don't have any column FFTs, we need to move the - // input into the output buffer for the remainder processing - if chunk_count == 0 { - let simd_ops = self.len() / 4; - for i in (0..simd_ops).step_by(4) { - buffer.store_complex(input.load_complex(i), i); + // process the column FFTs + for (c, twiddle_chunk) in self + .common_data + .twiddles + .chunks_exact(TWIDDLES_PER_COLUMN) + .take(chunk_count) + .enumerate() + { + let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; + + // Load columns from the input into registers + let mut columns = [AvxVector::zero(); ROW_COUNT]; + for i in 0..ROW_COUNT { + columns[i] = input.load_complex(index_base + len_per_row * i); } - } else { - // process the column FFTs - for (c, twiddle_chunk) in self - .common_data - .twiddles - .chunks_exact(TWIDDLES_PER_COLUMN) - .take(chunk_count) - .enumerate() - { - let index_base = c * A::VectorType::COMPLEX_PER_VECTOR; - - // Load columns from the input into registers - let mut columns = [AvxVector::zero(); ROW_COUNT]; - for i in 0..ROW_COUNT { - columns[i] = input.load_complex(index_base + len_per_row * i); - } - // apply our butterfly function down the columns - let output = $butterfly_fn(columns, self); + // apply our butterfly function down the columns + let output = $butterfly_fn(columns, self); - // always write the first row directly back without twiddles - buffer.store_complex(output[0], index_base); + // always write the first row directly back without twiddles + buffer.store_complex(output[0], index_base); - // for every other row, apply twiddle factors and then write back to memory - for i in 1..ROW_COUNT { - let twiddle = twiddle_chunk[i - 1]; - let output = AvxVector::mul_complex(twiddle, output[i]); - buffer.store_complex(output, index_base + len_per_row * i); - } + // for every other row, apply twiddle factors and then 
write back to memory + for i in 1..ROW_COUNT { + let twiddle = twiddle_chunk[i - 1]; + let output = AvxVector::mul_complex(twiddle, output[i]); + buffer.store_complex(output, index_base + len_per_row * i); } } diff --git a/src/avx/mod.rs b/src/avx/mod.rs index 6539daca..06061574 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -43,14 +43,14 @@ macro_rules! boilerplate_avx_fft { || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + crate::common::fft_error_immut( self.len(), input.len(), output.len(), required_scratch, scratch.len(), ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let scratch = &mut scratch[..required_scratch]; @@ -194,14 +194,14 @@ macro_rules! boilerplate_avx_fft_commondata { || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + crate::common::fft_error_immut( self.len(), input.len(), output.len(), self.get_immutable_scratch_len(), scratch.len(), ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let scratch = &mut scratch[..required_scratch]; diff --git a/src/common.rs b/src/common.rs index 18725de7..efd1ce1a 100644 --- a/src/common.rs +++ b/src/common.rs @@ -70,6 +70,39 @@ pub fn fft_error_outofplace( ); } +// Prints an error raised by an FFT algorithm's `process_immutable_with_scratch` method +// Marked cold and inline never to keep all formatting code out of the many monomorphized process_immutable_with_scratch methods 
+#[cold] +#[inline(never)] +pub fn fft_error_immut( + expected_len: usize, + actual_input: usize, + actual_output: usize, + expected_scratch: usize, + actual_scratch: usize, +) { + assert_eq!(actual_input, actual_output, "Provided FFT input buffer and output buffer must have the same length. Got input.len() = {}, output.len() = {}", actual_input, actual_output); + assert!( + actual_input >= expected_len, + "Provided FFT buffer was too small. Expected len = {}, got len = {}", + expected_len, + actual_input + ); + assert_eq!( + actual_input % expected_len, + 0, + "Input FFT buffer must be a multiple of FFT length. Expected multiple of {}, got len = {}", + expected_len, + actual_input + ); + assert!( + actual_scratch >= expected_scratch, + "Not enough scratch space was provided. Expected scratch len >= {}, got scratch len = {}", + expected_scratch, + actual_scratch + ); +} + macro_rules! boilerplate_fft_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { @@ -89,14 +122,14 @@ macro_rules! boilerplate_fft_oop { || scratch.len() < required_scratch { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + crate::common::fft_error_immut( self.len(), input.len(), output.len(), required_scratch, scratch.len(), ); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = array_utils::iter_chunks_zipped( @@ -232,7 +265,7 @@ macro_rules! 
boilerplate_fft { || input.len() < self.len() || output.len() != input.len() { - fft_error_outofplace( + crate::common::fft_error_immut( self.len(), input.len(), output.len(), diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index 9070ad3b..8cfde00f 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -188,15 +188,15 @@ macro_rules! boilerplate_fft_neon_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi_immut(input, output) }; if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index 28063719..adeeb86c 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -69,8 +69,8 @@ macro_rules! 
boilerplate_fft_neon_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 309400b8..b6596a8c 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -184,8 +184,8 @@ macro_rules! boilerplate_fft_sse_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 81c13b7d..9ab5ba82 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -69,8 +69,8 @@ macro_rules! 
boilerplate_fft_sse_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index e60aec95..56c774b7 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -142,8 +142,8 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index bcab6ad7..4f4211ad 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -69,8 +69,8 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); - return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { From 67a1408b41801032bae42f9c2f592c9fb70770fc Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Mon, 26 May 2025 13:11:19 -0500 Subject: [PATCH 03/26] Undoing fmt on autogen code --- src/algorithm/butterflies.rs | 2 +- src/algorithm/mixed_radix.rs | 6 ++++-- src/algorithm/raders_algorithm.rs | 4 ++-- src/avx/avx64_butterflies.rs | 2 +- src/neon/neon_butterflies.rs | 2 +- src/neon/neon_radix4.rs | 2 +- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 4371f9eb..2a39bca1 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -3,7 +3,7 @@ use num_complex::Complex; use crate::{common::FftNum, FftDirection}; use crate::array_utils::{self, DoubleBuf, LoadStore}; -use crate::common::{fft_error_inplace, fft_error_outofplace, fft_error_immut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/algorithm/mixed_radix.rs b/src/algorithm/mixed_radix.rs index 63cc9030..66405eba 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -168,7 +168,8 @@ impl MixedRadix { transpose::transpose(input, output, self.width, self.height); // STEP 2: perform FFTs of size `height` - self.height_size_fft.process_with_scratch(output, scratch_raw); + 
self.height_size_fft + .process_with_scratch(output, scratch_raw); // STEP 3: Apply twiddle factors for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { @@ -181,7 +182,8 @@ impl MixedRadix { transpose::transpose(output, scratch, self.height, self.width); // STEP 5: perform FFTs of size `width` - self.width_size_fft.process_with_scratch(scratch, inner_scratch); + self.width_size_fft + .process_with_scratch(scratch, inner_scratch); // STEP 6: transpose again transpose::transpose(scratch, output, self.width, self.height); diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index ee538e11..73382932 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -103,7 +103,7 @@ impl RadersAlgorithm { required_inner_scratch }; let inplace_scratch_len = inner_fft_len + extra_inner_scratch; - let immut_scratch_len = inplace_scratch_len * 2; + let immut_scratch_len = inner_fft_len + required_inner_scratch; //precompute a FFT of our reordered twiddle factors let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch]; @@ -117,7 +117,7 @@ impl RadersAlgorithm { primitive_root_inverse, len: reduced_len, - inplace_scratch_len: inner_fft_len + extra_inner_scratch, + inplace_scratch_len, outofplace_scratch_len: extra_inner_scratch, immut_scratch_len, direction, diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index 31479209..538c8098 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -7,7 +7,7 @@ use num_complex::Complex; use crate::array_utils; use crate::array_utils::DoubleBuf; use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace, fft_error_immut}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; use crate::{Direction, Fft, FftDirection, Length}; diff --git a/src/neon/neon_butterflies.rs 
b/src/neon/neon_butterflies.rs index 8cfde00f..d985b947 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index f52f6b28..41bcc632 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; From 0106c16620072aad83acea1c8da27f6b70e80ca4 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Mon, 26 May 2025 13:34:17 -0500 Subject: [PATCH 04/26] More work in progress --- src/avx/avx_mixed_radix.rs | 22 +++++++++------- src/lib.rs | 3 ++- src/neon/neon_radix4.rs | 36 +------------------------- src/wasm_simd/wasm_simd_butterflies.rs | 2 +- src/wasm_simd/wasm_simd_radix4.rs | 36 +------------------------- 5 files changed, 17 insertions(+), 82 deletions(-) diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index a7f2a386..7686aa11 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -1069,16 +1069,6 @@ impl MixedRadix16xnAvx { // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; if partial_remainder > 0 { - // We need to copy the 
partial remainders to the buffer - for c in 0..self.len() / len_per_row { - let cs = c * len_per_row + len_per_row - partial_remainder; - match partial_remainder { - 1 => buffer.store_partial1_complex(input.load_partial1_complex(cs), cs), - 2 => buffer.store_partial2_complex(input.load_partial2_complex(cs), cs), - 3 => buffer.store_partial3_complex(input.load_partial3_complex(cs), cs), - _ => unreachable!(), - } - } let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; let partial_remainder_twiddle_base = self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; @@ -1086,6 +1076,10 @@ impl MixedRadix16xnAvx { match partial_remainder { 1 => { + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial1_complex(input.load_partial1_complex(cs), cs); + } column_butterfly16_loadfn!( |index| buffer .load_partial1_complex(partial_remainder_base + len_per_row * index), @@ -1107,6 +1101,10 @@ impl MixedRadix16xnAvx { ); } 2 => { + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial2_complex(input.load_partial2_complex(cs), cs); + } column_butterfly16_loadfn!( |index| buffer .load_partial2_complex(partial_remainder_base + len_per_row * index), @@ -1128,6 +1126,10 @@ impl MixedRadix16xnAvx { ); } 3 => { + for c in 0..self.len() / len_per_row { + let cs = c * len_per_row + len_per_row - partial_remainder; + buffer.store_partial3_complex(input.load_partial3_complex(cs), cs); + } column_butterfly16_loadfn!( |index| buffer .load_partial3_complex(partial_remainder_base + len_per_row * index), diff --git a/src/lib.rs b/src/lib.rs index f304c15d..d88f3304 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,7 +269,8 @@ pub trait Fft: Length + Direction + Sync + Send { /// Returns the size of the scratch buffer required by `process_immutable_with_scratch` /// - /// For most FFT sizes, this method will return `self.len()`. 
For a few small sizes it will return 0, and for some special FFT sizes + /// For most FFT sizes, this method will return something between self.len() and self.len() * 2. + /// For a few small sizes it will return 0, and for some special FFT sizes /// (Sizes that require the use of Bluestein's Algorithm), this may return a scratch size larger than `self.len()`. /// The returned value may change from one version of RustFFT to the next. fn get_immutable_scratch_len(&self) -> usize; diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 41bcc632..5fd1699f 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -89,41 +89,7 @@ impl NeonRadix4 { output: &mut [Complex], _scratch: &mut [Complex], ) { - // copy the data into the output vector - if self.len() == self.base_len { - output.copy_from_slice(input); - } else { - bitreversed_transpose::, 4>(self.base_len, input, output); - } - - // Base-level FFTs - self.base_fft.process_with_scratch(output, &mut []); - - // cross-FFTs - const ROW_COUNT: usize = 4; - let mut cross_fft_len = self.base_len * ROW_COUNT; - let mut layer_twiddles: &[N::VectorType] = &self.twiddles; - - while cross_fft_len <= input.len() { - let num_rows = input.len() / cross_fft_len; - let num_scalar_columns = cross_fft_len / ROW_COUNT; - let num_vector_columns = num_scalar_columns / N::VectorType::COMPLEX_PER_VECTOR; - - for i in 0..num_rows { - butterfly_4::( - &mut output[i * cross_fft_len..], - layer_twiddles, - num_scalar_columns, - &self.rotation, - ) - } - - // skip past all the twiddle factors used in this layer - let twiddle_offset = num_vector_columns * (ROW_COUNT - 1); - layer_twiddles = &layer_twiddles[twiddle_offset..]; - - cross_fft_len *= ROW_COUNT; - } + self.perform_fft_out_of_place(input, output, _scratch); } //#[target_feature(enable = "neon")] diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 56c774b7..e1ef1c6f 100644 --- 
a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -150,7 +150,7 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index f4e17238..679ddb92 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -89,41 +89,7 @@ impl WasmSimdRadix4 { output: &mut [Complex], _scratch: &mut [Complex], ) { - // copy the data into the output vector - if self.len() == self.base_len { - output.copy_from_slice(input); - } else { - bitreversed_transpose::, 4>(self.base_len, input, output); - } - - // Base-level FFTs - self.base_fft.process_with_scratch(output, &mut []); - - // cross-FFTs - const ROW_COUNT: usize = 4; - let mut cross_fft_len = self.base_len * ROW_COUNT; - let mut layer_twiddles: &[S::VectorType] = &self.twiddles; - - while cross_fft_len <= input.len() { - let num_rows = input.len() / cross_fft_len; - let num_scalar_columns = cross_fft_len / ROW_COUNT; - let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR; - - for i in 0..num_rows { - butterfly_4::( - &mut output[i * cross_fft_len..], - layer_twiddles, - num_scalar_columns, - &self.rotation, - ) - } - - // skip past all the twiddle factors used in this layer - let twiddle_offset = num_vector_columns * (ROW_COUNT - 1); - layer_twiddles = &layer_twiddles[twiddle_offset..]; - - cross_fft_len *= ROW_COUNT; - } + self.perform_fft_out_of_place(input, output, _scratch); } #[target_feature(enable = "simd128")] From 
0620130ab931d724dc1c1f5cacc7f99f984cfb1d Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Mon, 26 May 2025 14:11:19 -0500 Subject: [PATCH 05/26] Correcting neon imports --- src/neon/mod.rs | 4 ---- src/neon/neon_prime_butterflies.rs | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/neon/mod.rs b/src/neon/mod.rs index ee8039b9..20b936cb 100644 --- a/src/neon/mod.rs +++ b/src/neon/mod.rs @@ -12,10 +12,6 @@ mod neon_utils; pub mod neon_planner; -pub use self::neon_butterflies::*; -pub use self::neon_prime_butterflies::*; -pub use self::neon_radix4::*; - use std::arch::aarch64::{float32x4_t, float64x2_t}; use crate::FftNum; diff --git a/src/neon/neon_prime_butterflies.rs b/src/neon/neon_prime_butterflies.rs index 688b5a20..8ef48108 100644 --- a/src/neon/neon_prime_butterflies.rs +++ b/src/neon/neon_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; From 8d2424c1018490ebf73af0a12c54047e1dd7d3a9 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Sat, 31 May 2025 10:16:51 -0500 Subject: [PATCH 06/26] More pull request comments --- src/sse/sse_common.rs | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 9ab5ba82..d9f88903 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -87,16 +87,41 @@ macro_rules! 
boilerplate_fft_sse_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( &self, input: &mut [Complex], output: &mut [Complex], - scratch: &mut [Complex], + _scratch: &mut [Complex], ) { - self.process_immutable_with_scratch(input, output, scratch); + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) + }, + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } } fn process_with_scratch(&self, buffer: &mut [Complex], scratch: &mut [Complex]) { if self.len() == 0 { From 434203b521df11db281e0b86fa4999e6aaa46fe9 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Sat, 31 May 2025 11:42:18 -0500 Subject: [PATCH 07/26] Reducing immut scratch length --- src/avx/avx_raders.rs | 22 ++++++++++------------ 1 file changed, 
10 insertions(+), 12 deletions(-) diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index b3156c5e..e68fb799 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -282,7 +282,7 @@ impl RadersAvx2 { inplace_scratch_len, outofplace_scratch_len: extra_inner_scratch, - immut_scratch_len: inplace_scratch_len * 2, + immut_scratch_len: inner_fft_len + required_inner_scratch + 1, direction, _phantom: std::marker::PhantomData, @@ -370,11 +370,9 @@ impl RadersAvx2 { let (first_input, _) = input.split_first().unwrap(); let (first_output, inner_output) = output.split_first_mut().unwrap(); + let (scratch2, extra_scratch) = scratch.split_at_mut(self.len()); + let (_, scratch) = scratch2.split_first_mut().unwrap(); - let (scratch, scratch2) = scratch.split_at_mut(scratch.len() / 2); - let (_, scratch_inner) = scratch2.split_first_mut().unwrap(); - - // perform the first of two inner FFTs self.inner_fft.process_with_scratch(inner_output, scratch); // inner_output[0] now contains the sum of elements 1..n. we want the sum of all inputs, so all we need to do is add the first input @@ -386,7 +384,7 @@ impl RadersAvx2 { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(scratch_inner); + array_utils::workaround_transmute_mut(scratch); let transmuted_inner_output: &mut [Complex] = array_utils::workaround_transmute_mut(inner_output); avx_vector::pairwise_complex_mul_conjugated( @@ -398,10 +396,10 @@ impl RadersAvx2 { // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. 
// Of course, we have to conjugate it, just like we conjugated the complex multiplied above - scratch_inner[0] = scratch_inner[0] + first_input.conj(); + scratch[0] = scratch[0] + first_input.conj(); - // execute the second FFT - self.inner_fft.process_with_scratch(scratch_inner, scratch); + self.inner_fft.process_with_scratch(scratch, extra_scratch); + scratch2[0] = *first_input; // copy the final values into the output, reordering as we go unsafe { @@ -448,13 +446,13 @@ impl RadersAvx2 { // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary - let transmuted_inner_input: &mut [Complex] = - array_utils::workaround_transmute_mut(inner_input); let transmuted_inner_output: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + let transmuted_inner_input: &mut [Complex] = array_utils::workaround_transmute_mut(inner_output); avx_vector::pairwise_complex_mul_conjugated( - transmuted_inner_output, transmuted_inner_input, + transmuted_inner_output, &self.twiddles, ) }; From b55da7be16a873591dfdb48086421b21410eaac4 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Sat, 31 May 2025 11:49:37 -0500 Subject: [PATCH 08/26] Adding explicit sse radix4 method --- src/sse/sse_common.rs | 4 +--- src/sse/sse_radix4.rs | 10 ++++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index d9f88903..42e779d2 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -78,9 +78,7 @@ macro_rules! 
boilerplate_fft_sse_oop { input, output, self.len(), - |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) - }, + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk, &mut []), ) }; diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index 6d8e27e1..01d77c00 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -92,6 +92,16 @@ impl SseRadix4 { } } + #[target_feature(enable = "sse4.1")] + unsafe fn perform_fft_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_out_of_place(input, output, _scratch); + } + #[target_feature(enable = "sse4.1")] unsafe fn perform_fft_out_of_place( &self, From 003cf3a8a7ffeb355267032a1b2d0bf665d147e9 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Sun, 1 Jun 2025 19:04:48 -0500 Subject: [PATCH 09/26] Correcting pipeline fails --- benches/bench_rustfft_scalar.rs | 8 +++++--- src/wasm_simd/wasm_simd_butterflies.rs | 4 ++-- src/wasm_simd/wasm_simd_common.rs | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/benches/bench_rustfft_scalar.rs b/benches/bench_rustfft_scalar.rs index 24eca927..371aebf1 100644 --- a/benches/bench_rustfft_scalar.rs +++ b/benches/bench_rustfft_scalar.rs @@ -33,7 +33,7 @@ impl Fft for Noop { fn get_outofplace_scratch_len(&self) -> usize { 0 } - + fn process_immutable_with_scratch( &self, _input: &[Complex], @@ -41,8 +41,10 @@ impl Fft for Noop { _scratch: &mut [Complex], ) { } - - fn get_immutable_scratch_len(&self) -> usize { 0 } + + fn get_immutable_scratch_len(&self) -> usize { + 0 + } } impl Length for Noop { fn len(&self) -> usize { diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index e1ef1c6f..3c5a6f70 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -142,7 +142,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; @@ -150,7 +150,7 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 4f4211ad..950c420e 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -69,7 +69,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } @@ -85,7 +85,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( From 75474216f2cd470cb6347e0caee60272fbf79208 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Wed, 4 Jun 2025 18:57:03 -0500 Subject: [PATCH 10/26] Changing immut default from out of place to immut for some algorithms --- src/algorithm/dft.rs | 4 +++- src/algorithm/radix3.rs | 3 ++- src/algorithm/radix4.rs | 3 ++- src/algorithm/radixn.rs | 3 ++- src/common.rs | 4 ++-- .../gen_simd_butterflies/src/templates/prime_template.hbs.rs | 2 +- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index c0e1b84a..5fa55026 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -78,7 +78,9 @@ impl Dft { self.perform_fft_immut(signal, spectrum, _scratch); } } -boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len()); +boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len(), |this: &Dft<_>| { + this.len() +}); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index c04a9a09..a8013f74 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -163,7 +163,8 @@ impl Radix3 { } } } -boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); +boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len, |this: &Radix3<_>| this + .inplace_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index d9bf8af0..60efd8d3 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -199,7 +199,8 @@ impl Radix4 { } } } 
-boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len); +boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len, |this: &Radix4<_>| this + .inplace_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index 80189936..bdb64e0b 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -326,7 +326,8 @@ impl RadixN { } } } -boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len); +boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len, |this: &RadixN<_>| this + .inplace_scratch_len); #[inline(never)] pub(crate) unsafe fn butterfly_2( diff --git a/src/common.rs b/src/common.rs index efd1ce1a..0d0eb246 100644 --- a/src/common.rs +++ b/src/common.rs @@ -104,7 +104,7 @@ pub fn fft_error_immut( } macro_rules! boilerplate_fft_oop { - ($struct_name:ident, $len_fn:expr) => { + ($struct_name:ident, $len_fn:expr, $immut_scratch_len:expr) => { impl Fft for $struct_name { fn process_immutable_with_scratch( &self, @@ -230,7 +230,7 @@ macro_rules! 
boilerplate_fft_oop { } #[inline(always)] fn get_immutable_scratch_len(&self) -> usize { - self.inplace_scratch_len() + $immut_scratch_len(self) } } impl Length for $struct_name { diff --git a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs index 90382866..6937f5c3 100644 --- a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs +++ b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; From 0fb3a7a8f273296bc1d33bce8c270edcdbb481c6 Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Wed, 4 Jun 2025 19:11:50 -0500 Subject: [PATCH 11/26] Correcting autogen check --- src/neon/neon_prime_butterflies.rs | 2 +- tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neon/neon_prime_butterflies.rs b/src/neon/neon_prime_butterflies.rs index 8ef48108..688b5a20 100644 --- a/src/neon/neon_prime_butterflies.rs +++ b/src/neon/neon_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs index 6937f5c3..90382866 100644 --- 
a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs +++ b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; From a12b9cb4049eca108798aeea4df61e8af2a37c4f Mon Sep 17 00:00:00 2001 From: Michael Ciraci Date: Wed, 4 Jun 2025 19:19:26 -0500 Subject: [PATCH 12/26] Explicitly using fft_error_immut in macro --- src/neon/neon_butterflies.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index d985b947..d12c544a 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -188,7 +188,7 @@ macro_rules! boilerplate_fft_neon_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi_immut(input, output) }; @@ -196,7 +196,7 @@ macro_rules! 
boilerplate_fft_neon_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( From 6c4459929d47c9f3ffa88e6d789cd6a614470082 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Wed, 4 Jun 2025 23:00:39 -0700 Subject: [PATCH 13/26] Keep the usage of fft_error_immut consistent with the rest of the error functions --- src/neon/neon_butterflies.rs | 6 +++--- src/neon/neon_common.rs | 2 +- src/neon/neon_prime_butterflies.rs | 2 +- src/sse/sse_butterflies.rs | 6 +++--- src/sse/sse_common.rs | 4 ++-- src/sse/sse_prime_butterflies.rs | 2 +- src/sse/sse_radix4.rs | 2 +- src/wasm_simd/wasm_simd_butterflies.rs | 6 +++--- src/wasm_simd/wasm_simd_common.rs | 4 ++-- src/wasm_simd/wasm_simd_prime_butterflies.rs | 2 +- src/wasm_simd/wasm_simd_radix4.rs | 2 +- .../src/templates/prime_template.hbs.rs | 2 +- 12 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index d12c544a..d985b947 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -188,7 +188,7 @@ macro_rules!
boilerplate_fft_neon_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi_immut(input, output) }; @@ -196,7 +196,7 @@ macro_rules! boilerplate_fft_neon_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index adeeb86c..635cf30f 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -85,7 +85,7 @@ macro_rules! 
boilerplate_fft_neon_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/neon/neon_prime_butterflies.rs b/src/neon/neon_prime_butterflies.rs index 688b5a20..8ef48108 100644 --- a/src/neon/neon_prime_butterflies.rs +++ b/src/neon/neon_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index b6596a8c..7ce8d217 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::DoubleBuf; use crate::array_utils::{workaround_transmute, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -184,7 +184,7 @@ macro_rules! 
boilerplate_fft_sse_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; @@ -192,7 +192,7 @@ macro_rules! boilerplate_fft_sse_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 42e779d2..86593abe 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -69,7 +69,7 @@ macro_rules! boilerplate_fft_sse_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } @@ -85,7 +85,7 @@ macro_rules! 
boilerplate_fft_sse_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/sse/sse_prime_butterflies.rs b/src/sse/sse_prime_butterflies.rs index 2ffaf978..ac9d5005 100644 --- a/src/sse/sse_prime_butterflies.rs +++ b/src/sse/sse_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index 01d77c00..d1912eeb 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 3c5a6f70..231c4340 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -6,7 +6,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use 
crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; @@ -142,7 +142,7 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; @@ -150,7 +150,7 @@ macro_rules! boilerplate_fft_wasm_simd_common_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 950c420e..0b46f4cd 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -69,7 +69,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } @@ -85,7 +85,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/wasm_simd/wasm_simd_prime_butterflies.rs b/src/wasm_simd/wasm_simd_prime_butterflies.rs index c56585e9..4d46e6b4 100644 --- a/src/wasm_simd/wasm_simd_prime_butterflies.rs +++ b/src/wasm_simd/wasm_simd_prime_butterflies.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index 679ddb92..76c60c49 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::sync::Arc; use crate::array_utils::{self, bitreversed_transpose, workaround_transmute_mut}; -use crate::common::{fft_error_inplace, 
fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs index 90382866..6937f5c3 100644 --- a/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs +++ b/tools/gen_simd_butterflies/src/templates/prime_template.hbs.rs @@ -8,7 +8,7 @@ use crate::{common::FftNum, FftDirection}; use crate::array_utils; use crate::array_utils::workaround_transmute_mut; use crate::array_utils::DoubleBuf; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::twiddles; use crate::{Direction, Fft, Length}; From 0281042af6fbf385d5f4cbd29ae7f52022681c01 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Wed, 4 Jun 2025 23:39:17 -0700 Subject: [PATCH 14/26] Dft doesn't need any scratch --- src/algorithm/dft.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index 5fa55026..f0ef525d 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -78,9 +78,7 @@ impl Dft { self.perform_fft_immut(signal, spectrum, _scratch); } } -boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len(), |this: &Dft<_>| { - this.len() -}); +boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len(), |_: &Dft<_>| 0); #[cfg(test)] mod unit_tests { From 59de6f9b30d0257594222aeb5985fb315b333555 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Wed, 4 Jun 2025 23:39:36 -0700 Subject: [PATCH 15/26] Updated comment in raders algorithm immutable impl --- src/algorithm/raders_algorithm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 73382932..3a4df60e 100644 --- 
a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -135,7 +135,7 @@ impl RadersAlgorithm { let (input_first, input) = input.split_first().unwrap(); let (scratch, extra_scratch) = scratch.split_at_mut(self.len() - 1); - // copy the input into the output, reordering as we go. also compute a sum of all elements + // copy the input into the scratch space, reordering as we go let mut input_index = 1; for output_element in scratch.iter_mut() { input_index = (input_index * self.primitive_root) % self.len; From ec19a53f4b664fa3373f747b73f82e128ad4688f Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Wed, 4 Jun 2025 23:40:05 -0700 Subject: [PATCH 16/26] Make sure error messages are correct in butterfly boilerplate --- src/algorithm/butterflies.rs | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 2a39bca1..7bae332b 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -47,7 +47,7 @@ macro_rules! boilerplate_fft_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( @@ -56,7 +56,31 @@ macro_rules! 
boilerplate_fft_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - self.process_immutable_with_scratch(input, output, _scratch); + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + self.perform_fft_butterfly(DoubleBuf { + input: in_chunk, + output: out_chunk, + }) + }; + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { From 3ce403968c454b5f1b70a23d7bad6d651a58d4a2 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Wed, 4 Jun 2025 23:49:27 -0700 Subject: [PATCH 17/26] Update scratch requests for RadixK --- src/algorithm/radix3.rs | 38 ++++++++++++++++++++++++++++++++++---- src/algorithm/radix4.rs | 4 +++- src/algorithm/radixn.rs | 5 ++++- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index a8013f74..cabd2118 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -32,8 +32,10 @@ pub struct Radix3 { len: usize, direction: FftDirection, + inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl Radix3 { @@ -108,6 +110,7 @@ impl Radix3 { inplace_scratch_len, 
outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -124,13 +127,39 @@ impl Radix3 { output: &mut [Complex], scratch: &mut [Complex], ) { - self.perform_fft_out_of_place(input, output, scratch); + // copy the data into the output vector + if self.len() == self.base_len { + output.copy_from_slice(input); + } else { + bitreversed_transpose::, 3>(self.base_len, input, output); + } + + // Base-level FFTs + self.base_fft.process_with_scratch(output, scratch); + + // cross-FFTs + const ROW_COUNT: usize = 3; + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = num_columns * (ROW_COUNT - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } } #[inline] fn perform_fft_out_of_place( &self, - input: &[Complex], + input: &mut [Complex], output: &mut [Complex], scratch: &mut [Complex], ) { @@ -142,7 +171,8 @@ impl Radix3 { } // Base-level FFTs - self.base_fft.process_with_scratch(output, scratch); + let base_scratch = if scratch.len() > 0 { scratch } else { input }; + self.base_fft.process_with_scratch(output, base_scratch); // cross-FFTs const ROW_COUNT: usize = 3; @@ -164,7 +194,7 @@ impl Radix3 { } } boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len, |this: &Radix3<_>| this - .inplace_scratch_len); + .immut_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 60efd8d3..2800a402 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -35,6 +35,7 @@ pub struct Radix4 { direction: FftDirection, inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl Radix4 { @@ -114,6 
+115,7 @@ impl Radix4 { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -200,7 +202,7 @@ impl Radix4 { } } boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len, |this: &Radix4<_>| this - .inplace_scratch_len); + .immut_scratch_len); #[cfg(test)] mod unit_tests { diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index bdb64e0b..b52765c7 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -43,8 +43,10 @@ pub(crate) struct RadixN { len: usize, direction: FftDirection, + inplace_scratch_len: usize, outofplace_scratch_len: usize, + immut_scratch_len: usize, } impl RadixN { @@ -148,6 +150,7 @@ impl RadixN { inplace_scratch_len, outofplace_scratch_len, + immut_scratch_len: base_inplace_scratch, } } @@ -327,7 +330,7 @@ impl RadixN { } } boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len, |this: &RadixN<_>| this - .inplace_scratch_len); + .immut_scratch_len); #[inline(never)] pub(crate) unsafe fn butterfly_2( From 99824d59a69a8853f145613ab7b3bf028841da5d Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:04:42 -0700 Subject: [PATCH 18/26] Make sure to use correct error messages in avx butterflies --- src/avx/avx32_butterflies.rs | 33 ++++++++++++++++++++++++++++++--- src/avx/avx64_butterflies.rs | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index eaaa3b5f..2b6a301a 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -69,7 +69,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( @@ -78,7 +78,34 @@ macro_rules! boilerplate_fft_simd_butterfly { output: &mut [Complex], _scratch: &mut [Complex], ) { - self.process_immutable_with_scratch(input, output, _scratch); + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices + let input_slice = workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_f32(DoubleBuf { + input: input_slice, + output: output_slice, + }); + } + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { @@ -224,7 +251,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index 538c8098..2d1639ac 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -67,16 +67,43 @@ macro_rules! boilerplate_fft_simd_butterfly { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( &self, input: &mut [Complex], output: &mut [Complex], - scratch: &mut [Complex], + _scratch: &mut [Complex], ) { - self.process_immutable_with_scratch(input, output, scratch); + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices + let input_slice = 
workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_f64(DoubleBuf { + input: input_slice, + output: output_slice, + }); + } + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } } fn process_with_scratch(&self, buffer: &mut [Complex], _scratch: &mut [Complex]) { if buffer.len() < self.len() { @@ -224,7 +251,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut(self.len(), input.len(), output.len(), 0, 0); } } fn process_outofplace_with_scratch( From 5e9af29ed7d9813eebb6ee276e61f6444780de6f Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:06:49 -0700 Subject: [PATCH 19/26] Keep scratch usage consistent in avx large butterflies --- src/avx/avx32_butterflies.rs | 17 ++++------------- src/avx/avx64_butterflies.rs | 31 +++++++++++-------------------- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index 2b6a301a..e0c85399 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -195,12 +195,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { } #[inline] - fn perform_fft_immut( - &self, - input: &[Complex], - output: &mut [Complex], - _scratch: &mut [Complex], - ) { + fn perform_fft_immut(&self, input: &[Complex], output: &mut [Complex]) { // Perform the column FFTs // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available unsafe { self.column_butterflies_and_transpose(input, output) }; @@ -216,7 +211,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { input: &mut [Complex], output: &mut [Complex], ) { - self.perform_fft_immut(input, output, &mut []); + self.perform_fft_immut(input, output); } } impl Fft for $struct_name { @@ -224,7 +219,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { &self, input: &[Complex], output: &mut [Complex], - scratch: &mut [Complex], + _scratch: &mut [Complex], ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us @@ -237,15 +232,11 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let transmuted_scratch: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(scratch) }; let result = array_utils::iter_chunks_zipped( transmuted_input, transmuted_output, self.len(), - |in_chunk, out_chunk| { - self.perform_fft_immut(in_chunk, out_chunk, transmuted_scratch) - }, + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), ); if result.is_err() { diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index 2d1639ac..e6bc7d0d 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -195,13 +195,14 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { } #[inline] - fn perform_fft_immut( - &self, - input: &[Complex], - output: &mut [Complex], - _scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(input, output); + fn perform_fft_immut(&self, input: &[Complex], output: &mut [Complex]) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + unsafe { self.column_butterflies_and_transpose(input, output) }; + + // process the row FFTs in-place in the output buffer + // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { self.row_butterflies(output) }; } #[inline] @@ -210,13 +211,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { input: &[Complex], output: &mut [Complex], ) { - // Perform the column FFTs - // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available - unsafe { self.column_butterflies_and_transpose(input, output) }; - - // process the row FFTs in-place in the output buffer - // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available - unsafe { self.row_butterflies(output) }; + self.perform_fft_immut(input, output); } } impl Fft for $struct_name { @@ -224,7 +219,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { &self, input: &[Complex], output: &mut [Complex], - scratch: &mut [Complex], + _scratch: &mut [Complex], ) { if input.len() < self.len() || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us @@ -237,15 +232,11 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let transmuted_scratch: &mut [Complex] = - unsafe { array_utils::workaround_transmute_mut(scratch) }; let result = array_utils::iter_chunks_zipped( transmuted_input, transmuted_output, self.len(), - |in_chunk, out_chunk| { - self.perform_fft_immut(in_chunk, out_chunk, transmuted_scratch) - }, + |in_chunk, out_chunk| self.perform_fft_immut(in_chunk, out_chunk), ); if result.is_err() { From 8f66acb8ff7b1f95e3e5862f57e40779088955df Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:21:33 -0700 Subject: [PATCH 20/26] Consistently forward from perform_fft_out_of_place to perform_fft_immut --- src/avx/avx_bluesteins.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index f04fa1cb..89a2efa4 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -317,15 +317,6 @@ impl BluesteinsAvx { input: &[Complex], output: &mut [Complex], scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(input, output, scratch); - } - - fn perform_fft_out_of_place( - &self, - input: &[Complex], - output: &mut [Complex], - scratch: &mut [Complex], ) { let (inner_input, inner_scratch) = scratch .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR); @@ -375,6 +366,15 @@ impl BluesteinsAvx { self.finalize_bluesteins(transmuted_inner_input, transmuted_output) } } + + fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, scratch); + } } #[cfg(test)] From f6148eaaca9e86a1985c3149468293fee9ffa708 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:22:28 -0700 Subject: [PATCH 21/26] Skip the initial copy in avx mixed radix 
partial butterflies --- src/avx/avx_mixed_radix.rs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index 7686aa11..a61f7909 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -238,16 +238,6 @@ macro_rules! mixedradix_column_butterflies { // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR; if partial_remainder > 0 { - // We need to copy the partial remainders to the buffer - for c in 0..self.len() / len_per_row { - let cs = c * len_per_row + len_per_row - partial_remainder; - match partial_remainder { - 1 => buffer.store_partial1_complex(input.load_partial1_complex(cs), cs), - 2 => buffer.store_partial2_complex(input.load_partial2_complex(cs), cs), - 3 => buffer.store_partial3_complex(input.load_partial3_complex(cs), cs), - _ => unreachable!(), - } - } let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR; let partial_remainder_twiddle_base = self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN; @@ -259,7 +249,7 @@ macro_rules! mixedradix_column_butterflies { let mut columns = [AvxVector::zero(); ROW_COUNT]; for i in 0..ROW_COUNT { columns[i] = - buffer.load_partial3_complex(partial_remainder_base + len_per_row * i); + input.load_partial3_complex(partial_remainder_base + len_per_row * i); } // apply our butterfly function down the columns @@ -282,12 +272,12 @@ macro_rules! 
mixedradix_column_butterflies { let mut columns = [AvxVector::zero(); ROW_COUNT]; if partial_remainder == 1 { for i in 0..ROW_COUNT { - columns[i] = buffer + columns[i] = input .load_partial1_complex(partial_remainder_base + len_per_row * i); } } else { for i in 0..ROW_COUNT { - columns[i] = buffer + columns[i] = input .load_partial2_complex(partial_remainder_base + len_per_row * i); } } From a34a7df7603ae586237cbb7001bbfee620b354fe Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:45:38 -0700 Subject: [PATCH 22/26] Use correct error messages in avx boilerplate --- src/avx/avx_bluesteins.rs | 2 +- src/avx/avx_mixed_radix.rs | 2 +- src/avx/avx_raders.rs | 2 +- src/avx/mod.rs | 12 ++++++------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index 89a2efa4..d85020ba 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use num_integer::div_ceil; use num_traits::Zero; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index a61f7909..a64476c9 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use num_integer::div_ceil; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{Direction, Fft, FftDirection, FftNum, Length}; use super::{AvxNum, CommonSimdData}; diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index e68fb799..439395e3 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -8,7 +8,7 @@ use num_traits::Zero; use primal_check::miller_rabin; use 
strength_reduce::StrengthReducedUsize; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{array_utils, FftDirection}; use crate::{math_utils, twiddles}; use crate::{Direction, Fft, FftNum, Length}; diff --git a/src/avx/mod.rs b/src/avx/mod.rs index 06061574..4aec9df8 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -43,7 +43,7 @@ macro_rules! boilerplate_avx_fft { || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut( + fft_error_immut( self.len(), input.len(), output.len(), @@ -64,11 +64,11 @@ macro_rules! boilerplate_avx_fft { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + fft_error_immut( self.len(), input.len(), output.len(), - self.get_outofplace_scratch_len(), + required_scratch, scratch.len(), ) } @@ -194,7 +194,7 @@ macro_rules! boilerplate_avx_fft_commondata { || output.len() != input.len() { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut( + fft_error_immut( self.len(), input.len(), output.len(), @@ -215,11 +215,11 @@ macro_rules! 
boilerplate_avx_fft_commondata { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace( + fft_error_immut( self.len(), input.len(), output.len(), - self.get_outofplace_scratch_len(), + self.get_immutable_scratch_len(), scratch.len(), ); } From 4f394c293c59b75c024de7110940b8c8dd74d498 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:49:12 -0700 Subject: [PATCH 23/26] Use correct error messages in scalar boilerplate --- src/algorithm/dft.rs | 2 +- src/algorithm/radix3.rs | 2 +- src/algorithm/radix4.rs | 2 +- src/algorithm/radixn.rs | 2 +- src/common.rs | 10 ++++++++-- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index f0ef525d..95f781f1 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -2,7 +2,7 @@ use num_complex::Complex; use num_traits::Zero; use crate::array_utils; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{twiddles, FftDirection}; use crate::{Direction, Fft, FftNum, Length}; diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index cabd2118..182ab4f4 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -5,7 +5,7 @@ use num_complex::Complex; use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9}; use crate::algorithm::radixn::butterfly_3; use crate::array_utils::{self, bitreversed_transpose, compute_logarithm}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; diff --git 
a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 2800a402..b040cfe0 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -7,7 +7,7 @@ use crate::algorithm::butterflies::{ }; use crate::algorithm::radixn::butterfly_4; use crate::array_utils::{self, bitreversed_transpose}; -use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index b52765c7..a85a6bf6 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use num_complex::Complex; use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor}; -use crate::common::{fft_error_inplace, fft_error_outofplace, RadixFactor}; +use crate::common::{fft_error_immut, fft_error_inplace, fft_error_outofplace, RadixFactor}; use crate::{common::FftNum, twiddles, FftDirection}; use crate::{Direction, Fft, Length}; diff --git a/src/common.rs b/src/common.rs index 0d0eb246..1670ebf5 100644 --- a/src/common.rs +++ b/src/common.rs @@ -122,7 +122,7 @@ macro_rules! boilerplate_fft_oop { || scratch.len() < required_scratch { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - crate::common::fft_error_immut( + fft_error_immut( self.len(), input.len(), output.len(), @@ -142,7 +142,13 @@ macro_rules! 
boilerplate_fft_oop { if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_immut( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); } } fn process_outofplace_with_scratch( From 801cca2adfeeea663d26f2cc8d849c04e5132618 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 00:58:59 -0700 Subject: [PATCH 24/26] Only need one perform_oop_fft_butterfly_multi since they're identical --- src/neon/neon_butterflies.rs | 54 +----------------------------------- 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index d985b947..571e8da0 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -52,38 +52,7 @@ macro_rules! 
boilerplate_fft_neon_f32_butterfly { Ok(()) } - // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait pub(crate) unsafe fn perform_oop_fft_butterfly_multi( - &self, - input: &mut [Complex], - output: &mut [Complex], - ) -> Result<(), ()> { - let len = input.len(); - let alldone = array_utils::iter_chunks_zipped_mut( - input, - output, - 2 * self.len(), - |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_parallel_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }, - ); - if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); - let output_slice = workaround_transmute_mut(output); - self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], - output: &mut output_slice[len - self.len()..], - }) - } - Ok(()) - } - - pub(crate) unsafe fn perform_oop_fft_butterfly_multi_immut( &self, input: &[Complex], output: &mut [Complex], @@ -137,28 +106,7 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { } // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait - //#[target_feature(enable = "neon")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( - &self, - input: &mut [Complex], - output: &mut [Complex], - ) -> Result<(), ()> { - array_utils::iter_chunks_zipped_mut( - input, - output, - self.len(), - |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); - let output_slice = workaround_transmute_mut(out_chunk); - self.perform_fft_contiguous(DoubleBuf { - input: input_slice, - output: output_slice, - }) - }, - ) - } - - pub(crate) unsafe fn perform_oop_fft_butterfly_multi_immut( &self, input: &[Complex], output: &mut [Complex], @@ -191,7 +139,7 @@ macro_rules! 
boilerplate_fft_neon_common_butterfly { fft_error_immut(self.len(), input.len(), output.len(), 0, 0); return; // Unreachable, because fft_error_immut asserts, but it helps codegen to put it here } - let result = unsafe { self.perform_oop_fft_butterfly_multi_immut(input, output) }; + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; if result.is_err() { // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, From 4ce7490d4a219a95194b81d1d838911f4b609441 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 01:03:23 -0700 Subject: [PATCH 25/26] Consistently forward from out of place to immutable --- src/neon/neon_radix4.rs | 20 ++++++++++---------- src/sse/sse_radix4.rs | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 5fd1699f..2fbfe4e4 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -88,16 +88,6 @@ impl NeonRadix4 { input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(input, output, _scratch); - } - - //#[target_feature(enable = "neon")] - unsafe fn perform_fft_out_of_place( - &self, - input: &[Complex], - output: &mut [Complex], - _scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -135,6 +125,16 @@ impl NeonRadix4 { cross_fft_len *= ROW_COUNT; } } + + //#[target_feature(enable = "neon")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_neon_oop!(NeonRadix4, |this: &NeonRadix4<_, _>| this.len); diff --git a/src/sse/sse_radix4.rs b/src/sse/sse_radix4.rs index d1912eeb..43eb3971 100644 --- a/src/sse/sse_radix4.rs +++ b/src/sse/sse_radix4.rs @@ -98,16 +98,6 @@ impl SseRadix4 { input: &[Complex], output: &mut 
[Complex], _scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(input, output, _scratch); - } - - #[target_feature(enable = "sse4.1")] - unsafe fn perform_fft_out_of_place( - &self, - input: &[Complex], - output: &mut [Complex], - _scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -145,6 +135,16 @@ impl SseRadix4 { cross_fft_len *= ROW_COUNT; } } + + #[target_feature(enable = "sse4.1")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len); From f04e8ce572264c11f154fbb7dbf35fb53ce34b1f Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Thu, 5 Jun 2025 01:18:32 -0700 Subject: [PATCH 26/26] Consistently forward from out of place to immutable --- src/wasm_simd/wasm_simd_radix4.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/wasm_simd/wasm_simd_radix4.rs b/src/wasm_simd/wasm_simd_radix4.rs index 76c60c49..73682570 100644 --- a/src/wasm_simd/wasm_simd_radix4.rs +++ b/src/wasm_simd/wasm_simd_radix4.rs @@ -88,16 +88,6 @@ impl WasmSimdRadix4 { input: &[Complex], output: &mut [Complex], _scratch: &mut [Complex], - ) { - self.perform_fft_out_of_place(input, output, _scratch); - } - - #[target_feature(enable = "simd128")] - unsafe fn perform_fft_out_of_place( - &self, - input: &[Complex], - output: &mut [Complex], - _scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -135,6 +125,16 @@ impl WasmSimdRadix4 { cross_fft_len *= ROW_COUNT; } } + + #[target_feature(enable = "simd128")] + unsafe fn perform_fft_out_of_place( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + self.perform_fft_immut(input, output, _scratch); + } } boilerplate_fft_wasm_simd_oop!(WasmSimdRadix4, 
|this: &WasmSimdRadix4<_, _>| this.len);