diff --git a/Cargo.toml b/Cargo.toml index 270c094..8b637e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ categories = ["algorithms", "compression", "multimedia::encoding", "science"] license = "MIT OR Apache-2.0" [features] -default = ["avx", "sse", "neon"] +default = ["avx","neon"] # On x86_64, the "avx" feature enables compilation of AVX-acclerated code. # Similarly, the "sse" feature enables compilation of SSE-accelerated code. diff --git a/src/algorithm/bluesteins_algorithm.rs b/src/algorithm/bluesteins_algorithm.rs index 9cc3642..deca84a 100644 --- a/src/algorithm/bluesteins_algorithm.rs +++ b/src/algorithm/bluesteins_algorithm.rs @@ -137,6 +137,18 @@ impl BluesteinsAlgorithm { } } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // TODO - Is there a better way to do this? + let (mut input_scratch, scratch) = scratch.split_at_mut(input.len()); + input_scratch.copy_from_slice(input); + self.process_outofplace_with_scratch(&mut input_scratch, output, scratch); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/butterflies.rs b/src/algorithm/butterflies.rs index 1292618..f661756 100644 --- a/src/algorithm/butterflies.rs +++ b/src/algorithm/butterflies.rs @@ -17,6 +17,39 @@ macro_rules! 
boilerplate_fft_butterfly { } } impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + self.perform_fft_butterfly(DoubleBuf { + input: in_chunk, + output: out_chunk, + }) + }; + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -29,7 +62,7 @@ macro_rules! boilerplate_fft_butterfly { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -56,7 +89,7 @@ macro_rules! 
boilerplate_fft_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| unsafe { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| unsafe { self.perform_fft_butterfly(chunk) }); @@ -104,6 +137,15 @@ impl Butterfly1 { } } impl Fft for Butterfly1 { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + output.copy_from_slice(&input); + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index e0b700c..0e956fd 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -68,6 +68,15 @@ impl Dft { } } } + + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } } boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len()); diff --git a/src/algorithm/good_thomas_algorithm.rs b/src/algorithm/good_thomas_algorithm.rs index 005ba52..e01328f 100644 --- a/src/algorithm/good_thomas_algorithm.rs +++ b/src/algorithm/good_thomas_algorithm.rs @@ -241,6 +241,15 @@ impl GoodThomasAlgorithm { self.reindex_output(scratch, buffer); } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -384,6 +393,38 @@ impl GoodThomasAlgorithmSmall { } } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // These asserts are for the unsafe blocks down below. 
we're relying on the optimizer to get rid of this assert + assert_eq!(self.len(), input.len()); + assert_eq!(self.len(), output.len()); + + let (input_map, output_map) = self.input_output_map.split_at(self.len()); + + // copy the input using our reordering mapping + for (output_element, &input_index) in output.iter_mut().zip(input_map.iter()) { + *output_element = input[input_index]; + } + + // run FFTs of size `width` + self.width_size_fft.process_with_scratch(output, scratch); + + // transpose + unsafe { array_utils::transpose_small(self.width, self.height, output, scratch) }; + + // run FFTs of size 'height' + self.height_size_fft.process_with_scratch(scratch, output); + + // copy to the output, using our output redordeing mapping + for (input_element, &output_index) in scratch.iter().zip(output_map.iter()) { + output[output_index] = *input_element; + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/mixed_radix.rs b/src/algorithm/mixed_radix.rs index 808ff0f..ac496a6 100644 --- a/src/algorithm/mixed_radix.rs +++ b/src/algorithm/mixed_radix.rs @@ -151,6 +151,45 @@ impl MixedRadix { transpose::transpose(scratch, buffer, self.width, self.height); } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // STEP 1: transpose + transpose::transpose(input, output, self.width, self.height); + + // STEP 2: perform FFTs of size `height` + // let height_scratch = if scratch.len() > input.len() { + // &mut scratch[..] + // } else { + // &mut input[..] 
+ // }; + self.height_size_fft + .process_with_scratch(output, scratch); + + // STEP 3: Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + // STEP 4: transpose again + transpose::transpose(output, scratch, self.height, self.width); + + // STEP 5: perform FFTs of size `width` + // let width_scratch = if scratch.len() > output.len() { + // &mut scratch[..] + // } else { + // &mut output[..] + // }; + self.width_size_fft + .process_with_scratch(scratch, output); + + // STEP 6: transpose again + transpose::transpose(scratch, output, self.width, self.height); + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], @@ -302,6 +341,34 @@ impl MixedRadixSmall { unsafe { array_utils::transpose_small(self.width, self.height, scratch, buffer) }; } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // SIX STEP FFT: + // STEP 1: transpose + unsafe { array_utils::transpose_small(self.width, self.height, input, output) }; + + // STEP 2: perform FFTs of size `height` + self.height_size_fft.process_with_scratch(output, scratch); + + // STEP 3: Apply twiddle factors + for (element, twiddle) in output.iter_mut().zip(self.twiddles.iter()) { + *element = *element * twiddle; + } + + // STEP 4: transpose again + unsafe { array_utils::transpose_small(self.height, self.width, output, scratch) }; + + // STEP 5: perform FFTs of size `width` + self.width_size_fft.process_with_scratch(scratch, output); + + // STEP 6: transpose again + unsafe { array_utils::transpose_small(self.width, self.height, scratch, output) }; + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/raders_algorithm.rs b/src/algorithm/raders_algorithm.rs index 7f059dd..c9fa76c 100644 --- a/src/algorithm/raders_algorithm.rs +++ b/src/algorithm/raders_algorithm.rs @@ -119,6 +119,15 @@ impl RadersAlgorithm { } } + fn 
perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index d392f8c..358f858 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -118,6 +118,15 @@ impl Radix3 { self.outofplace_scratch_len } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 33a804e..ffc998a 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -124,6 +124,15 @@ impl Radix4 { self.outofplace_scratch_len } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs index bd4d35c..4050329 100644 --- a/src/algorithm/radixn.rs +++ b/src/algorithm/radixn.rs @@ -174,6 +174,15 @@ impl RadixN { self.outofplace_scratch_len } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/array_utils.rs b/src/array_utils.rs index 8dd0aba..e9ac585 100644 --- a/src/array_utils.rs +++ b/src/array_utils.rs @@ -146,6 +146,29 @@ mod unit_tests { // Loop over exact chunks of the provided buffer. 
Very similar in semantics to ChunksExactMut, but generates smaller code and requires no modulo operations // Returns Ok() if every element ended up in a chunk, Err() if there was a remainder pub fn iter_chunks( + mut buffer: &[T], + chunk_size: usize, + mut chunk_fn: impl FnMut(&[T]), +) -> Result<(), ()> { + // Loop over the buffer, splicing off chunk_size at a time, and calling chunk_fn on each + while buffer.len() >= chunk_size { + let (head, tail) = buffer.split_at(chunk_size); + buffer = tail; + + chunk_fn(head); + } + + // We have a remainder if there's data still in the buffer -- in which case we want to indicate to the caller that there was an unwanted remainder + if buffer.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// Loop over exact chunks of the provided buffer. Very similar in semantics to ChunksExactMut, but generates smaller code and requires no modulo operations +// Returns Ok() if every element ended up in a chunk, Err() if there was a remainder +pub fn iter_chunks_mut( mut buffer: &mut [T], chunk_size: usize, mut chunk_fn: impl FnMut(&mut [T]), @@ -169,6 +192,44 @@ pub fn iter_chunks( // Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations // Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder pub fn iter_chunks_zipped( + mut buffer1: &[T], + mut buffer2: &mut [T], + chunk_size: usize, + mut chunk_fn: impl FnMut(&[T], &mut [T]), +) -> Result<(), ()> { + // If the two buffers aren't the same size, record the fact that they're different, then snip them to be the same size + let uneven = if buffer1.len() > buffer2.len() { + buffer1 = &buffer1[..buffer2.len()]; + true + } else if buffer2.len() > buffer1.len() { + buffer2 = &mut buffer2[..buffer1.len()]; + true + } else { + false + }; + + // Now that we know the two slices are the same length, loop over each one, splicing off chunk_size at a time, and calling chunk_fn on each + while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size { + let (head1, tail1) = buffer1.split_at(chunk_size); + buffer1 = tail1; + + let (head2, tail2) = buffer2.split_at_mut(chunk_size); + buffer2 = tail2; + + chunk_fn(head1, head2); + } + + // We have a remainder if the 2 chunks were uneven to start with, or if there's still data in the buffers -- in which case we want to indicate to the caller that there was an unwanted remainder + if !uneven && buffer1.len() == 0 { + Ok(()) + } else { + Err(()) + } +} + +// Loop over exact zipped chunks of the 2 provided buffers. 
Very similar in semantics to ChunksExactMut.zip(ChunksExactMut), but generates smaller code and requires no modulo operations +// Returns Ok() if every element of both buffers ended up in a chunk, Err() if there was a remainder +pub fn iter_chunks_zipped_mut( mut buffer1: &mut [T], mut buffer2: &mut [T], chunk_size: usize, diff --git a/src/avx/avx32_butterflies.rs b/src/avx/avx32_butterflies.rs index 87f9c10..8086960 100644 --- a/src/avx/avx32_butterflies.rs +++ b/src/avx/avx32_butterflies.rs @@ -5,7 +5,7 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; +use crate::array_utils::*; use crate::array_utils::DoubleBuf; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; @@ -37,6 +37,42 @@ macro_rules! boilerplate_fft_simd_butterfly { } impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices + let input_slice = workaround_transmute(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_f32(DoubleBuf { + input: input_slice, + output: output_slice, + }); + } + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible 
by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -49,7 +85,7 @@ macro_rules! boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -79,7 +115,7 @@ macro_rules! boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f32(workaround_transmute_mut::<_, Complex>(chunk)); @@ -155,6 +191,21 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { }; } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requres the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + unsafe { self.column_butterflies_and_transpose(input, output) }; + + // process the row FFTs in-place in the output buffer + // Safety: self.transpose() requres the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { self.row_butterflies(output) }; + } + #[inline] fn perform_fft_out_of_place( &self, @@ -171,6 +222,39 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { } } impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + unsafe { array_utils::workaround_transmute(input) }; + let transmuted_output: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(output) }; + let transmuted_scratch: &mut [Complex] = + unsafe { array_utils::workaround_transmute_mut(scratch) }; + let result = array_utils::iter_chunks_zipped( + transmuted_input, + transmuted_output, + self.len(), + |in_chunk, out_chunk| self.perform_fft_out_of_place_immut(in_chunk, out_chunk, transmuted_scratch), + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } + fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +272,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +300,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); diff --git a/src/avx/avx64_butterflies.rs b/src/avx/avx64_butterflies.rs index c15d545..57846fc 100644 --- a/src/avx/avx64_butterflies.rs +++ b/src/avx/avx64_butterflies.rs @@ -5,7 +5,7 @@ use std::mem::MaybeUninit; use num_complex::Complex; use crate::array_utils; -use crate::array_utils::workaround_transmute_mut; +use crate::array_utils::*; use crate::array_utils::DoubleBuf; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles}; @@ -35,6 +35,14 @@ macro_rules! boilerplate_fft_simd_butterfly { } } impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -47,7 +55,7 @@ macro_rules! boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -77,7 +85,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { unsafe { // Specialization workaround: See the comments in FftPlannerAvx::new() for why we have to transmute these slices self.perform_fft_f64(workaround_transmute_mut::<_, Complex>(chunk)); @@ -155,6 +163,15 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { }; } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + #[inline] fn perform_fft_out_of_place( &self, @@ -171,6 +188,14 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { } } impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -188,7 +213,7 @@ macro_rules! boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(input) }; let transmuted_output: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(output) }; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( transmuted_input, transmuted_output, self.len(), @@ -216,7 +241,7 @@ macro_rules! 
boilerplate_fft_simd_butterfly_with_scratch { unsafe { array_utils::workaround_transmute_mut(buffer) }; let transmuted_scratch: &mut [Complex] = unsafe { array_utils::workaround_transmute_mut(scratch) }; - let result = array_utils::iter_chunks(transmuted_buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(transmuted_buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, transmuted_scratch) }); diff --git a/src/avx/avx_bluesteins.rs b/src/avx/avx_bluesteins.rs index 30e8f8e..63db326 100644 --- a/src/avx/avx_bluesteins.rs +++ b/src/avx/avx_bluesteins.rs @@ -311,6 +311,61 @@ impl BluesteinsAvx { } } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let (inner_input, inner_scratch) = scratch + .split_at_mut(self.inner_fft_multiplier.len() * A::VectorType::COMPLEX_PER_VECTOR); + + // do the necessary setup for bluestein's algorithm: copy the data to the inner buffers, apply some twiddle factors, zero out the rest of the inner buffer + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = array_utils::workaround_transmute(input); + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + self.prepare_bluesteins(transmuted_input, transmuted_inner_input) + } + + // run our inner forward FFT + self.common_data + .inner_fft + .process_with_scratch(inner_input, inner_scratch); + + // Multiply our inner FFT output by our precomputed data. Then, conjugate the result to set up for an inverse FFT. + // We can conjugate the result of multiplication by conjugating both inputs. 
We pre-conjugated the multiplier array, + // so we just need to conjugate inner_input, which the pairwise_complex_multiply_conjugated function will handle + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + Self::pairwise_complex_multiply_conjugated( + transmuted_inner_input, + &self.inner_fft_multiplier, + ) + }; + + // inverse FFT. we're computing a forward but we're massaging it into an inverse by conjugating the inputs and outputs + self.common_data + .inner_fft + .process_with_scratch(inner_input, inner_scratch); + + // finalize the result + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + + self.finalize_bluesteins(transmuted_inner_input, transmuted_output) + } + } + fn perform_fft_out_of_place( &self, input: &[Complex], diff --git a/src/avx/avx_mixed_radix.rs b/src/avx/avx_mixed_radix.rs index ed6de0f..38ab6ae 100644 --- a/src/avx/avx_mixed_radix.rs +++ b/src/avx/avx_mixed_radix.rs @@ -69,6 +69,41 @@ macro_rules! 
boilerplate_mixedradix { } } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + // Perform the column FFTs + // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available + output.copy_from_slice(input); + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + + self.perform_column_butterflies(transmuted_output); + } + + self.common_data + .inner_fft + .process_with_scratch(output, scratch); + scratch[..input.len()].copy_from_slice(output); + + // Transpose + // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &[Complex] = + array_utils::workaround_transmute(scratch); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + + self.transpose(transmuted_input, transmuted_output) + } + } + #[inline] fn perform_fft_out_of_place( &self, diff --git a/src/avx/avx_raders.rs b/src/avx/avx_raders.rs index 34b76b1..e2d0510 100644 --- a/src/avx/avx_raders.rs +++ b/src/avx/avx_raders.rs @@ -351,6 +351,75 @@ impl RadersAvx2 { } } + fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + scratch.copy_from_slice(input); + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = 
array_utils::workaround_transmute_mut(scratch); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.prepare_raders(transmuted_input, transmuted_output) + } + + let (first_input, inner_input) = scratch.split_first_mut().unwrap(); + let (first_output, inner_output) = output.split_first_mut().unwrap(); + + // perform the first of two inner FFTs + // let inner_scratch = if scratch.len() > 0 { + // &mut scratch[..] + // } else { + // &mut inner_input[..] + // }; + self.inner_fft + .process_with_scratch(inner_output, inner_input); + + // inner_output[0] now contains the sum of elements 1..n. we want the sum of all inputs, so all we need to do is add the first input + *first_output = inner_output[0] + *first_input; + + // multiply the inner result with our cached setup data + // also conjugate every entry. this sets us up to do an inverse FFT + // (because an inverse FFT is equivalent to a normal FFT where you conjugate both the inputs and outputs) + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_inner_input: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_input); + let transmuted_inner_output: &mut [Complex] = + array_utils::workaround_transmute_mut(inner_output); + avx_vector::pairwise_complex_mul_conjugated( + transmuted_inner_output, + transmuted_inner_input, + &self.twiddles, + ) + }; + + // We need to add the first input value to all output values. We can accomplish this by adding it to the DC input of our inner ifft. + // Of course, we have to conjugate it, just like we conjugated the complex multiplied above + inner_input[0] = inner_input[0] + first_input.conj(); + + // execute the second FFT + // let inner_scratch = if scratch.len() > 0 { + // scratch + // } else { + // &mut inner_output[..] 
+ // }; + self.inner_fft + .process_with_scratch(inner_input, inner_output); + + // copy the final values into the output, reordering as we go + unsafe { + // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary + let transmuted_input: &mut [Complex] = array_utils::workaround_transmute_mut(scratch); + let transmuted_output: &mut [Complex] = + array_utils::workaround_transmute_mut(output); + self.finalize_raders(transmuted_input, transmuted_output); + } + } + fn perform_fft_out_of_place( &self, input: &mut [Complex], diff --git a/src/avx/mod.rs b/src/avx/mod.rs index dc3dbea..f3a5469 100644 --- a/src/avx/mod.rs +++ b/src/avx/mod.rs @@ -30,6 +30,50 @@ struct CommonSimdData { macro_rules! boilerplate_avx_fft { ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + let required_scratch = self.get_inplace_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_inplace_scratch_len(), + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place_immut(in_chunk, out_chunk, scratch) + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + 
// but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ) + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -53,7 +97,7 @@ macro_rules! boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -88,7 +132,7 @@ macro_rules! boilerplate_avx_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); @@ -130,6 +174,54 @@ macro_rules! boilerplate_avx_fft { macro_rules! boilerplate_avx_fft_commondata { ($struct_name:ident) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.get_inplace_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_inplace_scratch_len(), + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place_immut(in_chunk, out_chunk, scratch) + }, + ); + + 
if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -157,7 +249,7 @@ macro_rules! boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -196,7 +288,7 @@ macro_rules! boilerplate_avx_fft_commondata { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/common.rs b/src/common.rs index 008be6d..0702bc4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -73,6 +73,47 @@ pub fn fft_error_outofplace( macro_rules! 
boilerplate_fft_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.get_outofplace_scratch_len(); + if input.len() < self.len() + || output.len() != input.len() + || scratch.len() < required_scratch + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place_immut(in_chunk, out_chunk, scratch) + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -99,7 +140,7 @@ macro_rules! boilerplate_fft_oop { return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -132,7 +173,7 @@ macro_rules! 
boilerplate_fft_oop { } let (scratch, extra_scratch) = scratch.split_at_mut(self.len()); - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, extra_scratch); chunk.copy_from_slice(scratch); }); @@ -175,6 +216,55 @@ macro_rules! boilerplate_fft_oop { macro_rules! boilerplate_fft { ($struct_name:ident, $len_fn:expr, $inplace_scratch_len_fn:expr, $out_of_place_scratch_len_fn:expr) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + let required_scratch = self.get_outofplace_scratch_len(); + if scratch.len() < required_scratch + || input.len() < self.len() + || output.len() != input.len() + { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let scratch = &mut scratch[..required_scratch]; + let result = array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place_immut(in_chunk, out_chunk, scratch) + }, + ); + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + self.get_outofplace_scratch_len(), + scratch.len(), + ); + } + } + fn process_outofplace_with_scratch( &self, input: &mut 
[Complex], @@ -202,7 +292,7 @@ macro_rules! boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -241,7 +331,7 @@ macro_rules! boilerplate_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/lib.rs b/src/lib.rs index 74b4787..f4d947e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -236,6 +236,13 @@ pub trait Fft: Length + Direction + Sync + Send { scratch: &mut [Complex], ); + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ); + /// Returns the size of the scratch buffer required by `process_with_scratch` /// /// For most FFT sizes, this method will return `self.len()`. For a few small sizes it will return 0, and for some special FFT sizes diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index dc84ffa..36317a4 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -43,7 +43,7 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -55,7 +55,7 @@ macro_rules! 
boilerplate_fft_neon_f32_butterfly { // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); @@ -64,7 +64,7 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { output, 2 * self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_parallel_fft_contiguous(DoubleBuf { input: input_slice, @@ -73,10 +73,10 @@ macro_rules! boilerplate_fft_neon_f32_butterfly { }, ); if alldone.is_err() && input.len() >= self.len() { - let input_slice = workaround_transmute_mut(input); + let input_slice = crate::array_utils::workaround_transmute(input); let output_slice = workaround_transmute_mut(output); self.perform_fft_contiguous(DoubleBuf { - input: &mut input_slice[len - self.len()..], + input: &input_slice[len - self.len()..], output: &mut output_slice[len - self.len()..], }) } @@ -101,7 +101,7 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -110,11 +110,11 @@ macro_rules! 
boilerplate_fft_neon_f64_butterfly { //#[target_feature(enable = "neon")] pub(crate) unsafe fn perform_oop_fft_butterfly_multi( &self, - input: &mut [Complex], + input: &[Complex], output: &mut [Complex], ) -> Result<(), ()> { array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { - let input_slice = workaround_transmute_mut(in_chunk); + let input_slice = crate::array_utils::workaround_transmute(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { input: input_slice, @@ -130,6 +130,25 @@ macro_rules! boilerplate_fft_neon_f64_butterfly { macro_rules! boilerplate_fft_neon_common_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + let result = unsafe { self.perform_oop_fft_butterfly_multi(input, output) }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], diff --git a/src/neon/neon_common.rs b/src/neon/neon_common.rs index 826a320..f0d3abe 100644 --- a/src/neon/neon_common.rs +++ b/src/neon/neon_common.rs @@ -57,6 +57,39 @@ macro_rules! 
separate_interleaved_complex_f32 { macro_rules! boilerplate_fft_neon_oop { ($struct_name:ident, $len_fn:expr) => { impl Fft for $struct_name { + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + _scratch: &mut [Complex], + ) { + if self.len() == 0 { + return; + } + + if input.len() < self.len() || output.len() != input.len() { + // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here + } + + let result = unsafe { + array_utils::iter_chunks_zipped( + input, + output, + self.len(), + |in_chunk, out_chunk| { + self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) + }, + ) + }; + + if result.is_err() { + // We want to trigger a panic, because the buffer sizes weren't cleanly divisible by the FFT size, + // but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us + fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + } + } fn process_outofplace_with_scratch( &self, input: &mut [Complex], @@ -74,7 +107,7 @@ macro_rules! boilerplate_fft_neon_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +142,7 @@ macro_rules! boilerplate_fft_neon_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -180,7 +213,7 @@ macro_rules! 
boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +252,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/neon/neon_radix4.rs b/src/neon/neon_radix4.rs index 841a782..aeff1a4 100644 --- a/src/neon/neon_radix4.rs +++ b/src/neon/neon_radix4.rs @@ -83,6 +83,15 @@ impl NeonRadix4 { } } + unsafe fn perform_fft_out_of_place_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + todo!() + } + //#[target_feature(enable = "neon")] unsafe fn perform_fft_out_of_place( &self, diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 97c6bcf..b83f642 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -48,7 +48,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -65,7 +65,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( + let alldone = array_utils::iter_chunks_zipped_mut( input, output, 2 * self.len(), @@ -107,7 +107,7 @@ macro_rules! 
boilerplate_fft_sse_f32_butterfly_noparallel { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -119,7 +119,7 @@ macro_rules! boilerplate_fft_sse_f32_butterfly_noparallel { input: &mut [Complex], output: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { + array_utils::iter_chunks_zipped_mut(input, output, self.len(), |in_chunk, out_chunk| { let input_slice = workaround_transmute_mut(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { @@ -147,7 +147,7 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -159,7 +159,7 @@ macro_rules! boilerplate_fft_sse_f64_butterfly { input: &mut [Complex], output: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { + array_utils::iter_chunks_zipped_mut(input, output, self.len(), |in_chunk, out_chunk| { let input_slice = workaround_transmute_mut(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { diff --git a/src/sse/sse_common.rs b/src/sse/sse_common.rs index 0c31995..a667872 100644 --- a/src/sse/sse_common.rs +++ b/src/sse/sse_common.rs @@ -74,7 +74,7 @@ macro_rules! boilerplate_fft_sse_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +109,7 @@ macro_rules! 
boilerplate_fft_sse_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) @@ -180,7 +180,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks_zipped( + let result = array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -219,7 +219,7 @@ macro_rules! boilerplate_sse_fft { } let scratch = &mut scratch[..required_scratch]; - let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { + let result = array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_inplace(chunk, scratch) }); diff --git a/src/test_utils.rs b/src/test_utils.rs index 1bbb48b..4a60f19 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -213,6 +213,20 @@ impl Fft for BigScratchAlgorithm { fn get_outofplace_scratch_len(&self) -> usize { self.outofplace_scratch } + + fn process_outofplace_with_scratch_immut( + &self, + input: &[Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + assert!( + scratch.len() >= self.outofplace_scratch, + "Not enough OOP scratch provided, self={:?}, provided scratch={}", + &self, + scratch.len() + ); + } } impl Length for BigScratchAlgorithm { fn len(&self) -> usize { diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 90af2e8..5b5ec87 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -46,7 +46,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_f32_butterfly { buffer: &mut [Complex], ) -> Result<(), ()> { let len = buffer.len(); - let alldone = array_utils::iter_chunks(buffer, 2 * self.len(), |chunk| { + let alldone = array_utils::iter_chunks_mut(buffer, 2 * self.len(), |chunk| { self.perform_parallel_fft_butterfly(chunk) }); if alldone.is_err() && buffer.len() >= self.len() { @@ -63,7 +63,7 @@ macro_rules! boilerplate_fft_wasm_simd_f32_butterfly { output: &mut [Complex], ) -> Result<(), ()> { let len = input.len(); - let alldone = array_utils::iter_chunks_zipped( + let alldone = array_utils::iter_chunks_zipped_mut( input, output, 2 * self.len(), @@ -105,7 +105,7 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { &self, buffer: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_butterfly(chunk) }) } @@ -117,7 +117,7 @@ macro_rules! boilerplate_fft_wasm_simd_f64_butterfly { input: &mut [Complex], output: &mut [Complex], ) -> Result<(), ()> { - array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { + array_utils::iter_chunks_zipped_mut(input, output, self.len(), |in_chunk, out_chunk| { let input_slice = workaround_transmute_mut(in_chunk); let output_slice = workaround_transmute_mut(out_chunk); self.perform_fft_contiguous(DoubleBuf { diff --git a/src/wasm_simd/wasm_simd_common.rs b/src/wasm_simd/wasm_simd_common.rs index 787170d..4c7a20c 100644 --- a/src/wasm_simd/wasm_simd_common.rs +++ b/src/wasm_simd/wasm_simd_common.rs @@ -74,7 +74,7 @@ macro_rules! boilerplate_fft_wasm_simd_oop { } let result = unsafe { - array_utils::iter_chunks_zipped( + array_utils::iter_chunks_zipped_mut( input, output, self.len(), @@ -109,7 +109,7 @@ macro_rules! 
boilerplate_fft_wasm_simd_oop { let scratch = &mut scratch[..required_scratch]; let result = unsafe { - array_utils::iter_chunks(buffer, self.len(), |chunk| { + array_utils::iter_chunks_mut(buffer, self.len(), |chunk| { self.perform_fft_out_of_place(chunk, scratch, &mut []); chunk.copy_from_slice(scratch); }) diff --git a/tests/test_basic.rs b/tests/test_basic.rs new file mode 100644 index 0000000..18f3aad --- /dev/null +++ b/tests/test_basic.rs @@ -0,0 +1,26 @@ +use num_complex::Complex; +use rustfft::{Fft, FftPlanner}; +use std::sync::Arc; + +// Just a very simple test to help debugging +#[test] +fn test_100() { + // let len = 102; + for len in 37..5000 { + dbg!(len); + let mut p = FftPlanner::::new(); + let planner: Arc> = p.plan_fft_forward(len); + + let mut input: Vec<_> = (0..len).map(|i| Complex::new(i as f32, 0.0)).collect(); + // dbg!(planner.get_inplace_scratch_len(), planner.get_outofplace_scratch_len()); + let mut output = input.clone(); + let mut output2 = output.clone(); + let mut scratch = vec![Complex::::ZERO; planner.get_inplace_scratch_len() + len]; + // planner.process_outofplace_with_scratch(&mut input, &mut output, &mut scratch); + planner.process_outofplace_with_scratch_immut(&input, &mut output, &mut scratch); + + planner.process_outofplace_with_scratch(&mut input, &mut output2, &mut scratch); + + assert_eq!(output, output2); + } +}