Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,4 @@ jobs:
with:
version: "latest"
- name: Run test suites with wasm-pack
run: wasm-pack test --node --lib --features wasm_simd
run: wasm-pack test --node --features wasm_simd
48 changes: 33 additions & 15 deletions src/algorithm/raders_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_complex::Complex;
use num_integer::Integer;
use num_traits::Zero;
use primal_check::miller_rabin;
use strength_reduce::StrengthReducedUsize;
use strength_reduce::StrengthReducedU64;

use crate::math_utils;
use crate::{common::FftNum, twiddles, FftDirection};
Expand Down Expand Up @@ -42,10 +42,10 @@ pub struct RadersAlgorithm<T> {
inner_fft: Arc<dyn Fft<T>>,
inner_fft_data: Box<[Complex<T>]>,

primitive_root: usize,
primitive_root_inverse: usize,
primitive_root: u64,
primitive_root_inverse: u64,

len: StrengthReducedUsize,
len: StrengthReducedU64,
inplace_scratch_len: usize,
outofplace_scratch_len: usize,
immut_scratch_len: usize,
Expand All @@ -68,10 +68,10 @@ impl<T: FftNum> RadersAlgorithm<T> {
assert!(miller_rabin(len as u64), "For raders algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);

let direction = inner_fft.fft_direction();
let reduced_len = StrengthReducedUsize::new(len);
let reduced_len = StrengthReducedU64::new(len as u64);

// compute the primitive root and its inverse for this size
let primitive_root = math_utils::primitive_root(len as u64).unwrap() as usize;
let primitive_root = math_utils::primitive_root(len as u64).unwrap();

// compute the multiplicative inverse of primative_root mod len and vice versa.
// i64::extended_gcd will compute both the inverse of left mod right, and the inverse of right mod left, but we're only goingto use one of them
Expand All @@ -81,7 +81,7 @@ impl<T: FftNum> RadersAlgorithm<T> {
gcd_data.x
} else {
gcd_data.x + len as i64
} as usize;
} as u64;

// precompute the coefficients to use inside the process method
let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
Expand All @@ -91,7 +91,8 @@ impl<T: FftNum> RadersAlgorithm<T> {
let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
*input_cell = twiddle * inner_fft_scale;

twiddle_input = (twiddle_input * primitive_root_inverse) % reduced_len;
twiddle_input =
((twiddle_input as u64 * primitive_root_inverse) % reduced_len) as usize;
}

let required_inner_scratch = inner_fft.get_inplace_scratch_len();
Expand Down Expand Up @@ -136,7 +137,7 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the input into the scratch space, reordering as we go
let mut input_index = 1;
for output_element in scratch.iter_mut() {
input_index = (input_index * self.primitive_root) % self.len;
input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;

let input_element = input[input_index - 1];
*output_element = input_element;
Expand Down Expand Up @@ -164,7 +165,8 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the final values into the output, reordering as we go
let mut output_index = 1;
for scratch_element in scratch {
output_index = (output_index * self.primitive_root_inverse) % self.len;
output_index =
((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
output[output_index - 1] = scratch_element.conj();
}
}
Expand All @@ -182,7 +184,7 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the input into the output, reordering as we go. also compute a sum of all elements
let mut input_index = 1;
for output_element in output.iter_mut() {
input_index = (input_index * self.primitive_root) % self.len;
input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;

let input_element = input[input_index - 1];
*output_element = input_element;
Expand Down Expand Up @@ -225,7 +227,8 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the final values into the output, reordering as we go
let mut output_index = 1;
for input_element in input {
output_index = (output_index * self.primitive_root_inverse) % self.len;
output_index =
((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
output[output_index - 1] = input_element.conj();
}
}
Expand All @@ -239,7 +242,7 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the buffer into the scratch, reordering as we go. also compute a sum of all elements
let mut input_index = 1;
for scratch_element in scratch.iter_mut() {
input_index = (input_index * self.primitive_root) % self.len;
input_index = ((input_index as u64 * self.primitive_root) % self.len) as usize;

let buffer_element = buffer[input_index - 1];
*scratch_element = buffer_element;
Expand Down Expand Up @@ -273,14 +276,15 @@ impl<T: FftNum> RadersAlgorithm<T> {
// copy the final values into the output, reordering as we go
let mut output_index = 1;
for scratch_element in scratch {
output_index = (output_index * self.primitive_root_inverse) % self.len;
output_index =
((output_index as u64 * self.primitive_root_inverse) % self.len) as usize;
buffer[output_index - 1] = scratch_element.conj();
}
}
}
boilerplate_fft!(
RadersAlgorithm,
|this: &RadersAlgorithm<_>| this.len.get(),
|this: &RadersAlgorithm<_>| this.len.get() as usize,
|this: &RadersAlgorithm<_>| this.inplace_scratch_len,
|this: &RadersAlgorithm<_>| this.outofplace_scratch_len,
|this: &RadersAlgorithm<_>| this.immut_scratch_len
Expand All @@ -291,6 +295,7 @@ mod unit_tests {
use super::*;
use crate::algorithm::Dft;
use crate::test_utils::check_fft_algorithm;
use crate::FftPlanner;
use std::sync::Arc;

#[test]
Expand All @@ -303,6 +308,19 @@ mod unit_tests {
}
}

#[test]
fn test_raders_32bit_overflow() {
// Construct and use Raders instances for a few large primes
// that could panic due to overflow errors on 32-bit builds.
let mut planner = FftPlanner::<f32>::new();
for len in [112501, 216569, 417623] {
let inner_fft = planner.plan_fft_forward(len - 1);
let fft: RadersAlgorithm<f32> = RadersAlgorithm::new(inner_fft);
let mut data = vec![Complex::new(0.0, 0.0); len];
fft.process(&mut data);
}
}

fn test_raders_with_length(len: usize, direction: FftDirection) {
let inner_fft = Arc::new(Dft::new(len - 1, direction));
let fft = RadersAlgorithm::new(inner_fft);
Expand Down
7 changes: 4 additions & 3 deletions src/wasm_simd/wasm_simd_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ macro_rules! interleave_complex_f32 {

/// Shuffle elements to interleave two contiguous sets of f32, from an array of simd vectors to a new array of simd vectors
/// This statement:
/// ```
///
/// let values = separate_interleaved_complex_f32!(input, {0, 2, 4});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// extract_lo_lo_f32(input[0], input[1]),
/// extract_lo_lo_f32(input[2], input[3]),
Expand All @@ -46,6 +46,7 @@ macro_rules! interleave_complex_f32 {
/// extract_hi_hi_f32(input[2], input[3]),
/// extract_hi_hi_f32(input[4], input[5]),
/// ];
///
macro_rules! separate_interleaved_complex_f32 {
($input:ident, { $($idx:literal),* }) => {
[
Expand Down
71 changes: 35 additions & 36 deletions src/wasm_simd/wasm_simd_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@ use super::WasmNum;
/// Read these indexes from an WasmSimdArray and build an array of simd vectors.
/// Takes a name of a vector to read from, and a list of indexes to read.
/// This statement:
/// ```
///
/// let values = read_complex_to_array!(input, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// input.load_complex(0),
/// input.load_complex(1),
/// input.load_complex(2),
/// input.load_complex(3),
/// ];
/// ```
macro_rules! read_complex_to_array {
($input:ident, { $($idx:literal),* }) => {
[
Expand All @@ -36,18 +35,18 @@ macro_rules! read_complex_to_array {
/// Read these indexes from an WasmSimdArray and build an array or partially filled simd vectors.
/// Takes a name of a vector to read from, and a list of indexes to read.
/// This statement:
/// ```
///
/// let values = read_partial1_complex_to_array!(input, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// input.load1_complex(0),
/// input.load1_complex(1),
/// input.load1_complex(2),
/// input.load1_complex(3),
/// ];
/// ```
///
macro_rules! read_partial1_complex_to_array {
($input:ident, { $($idx:literal),* }) => {
[
Expand All @@ -61,18 +60,18 @@ macro_rules! read_partial1_complex_to_array {
/// Write these indexes of an array of simd vectors to the same indexes of an WasmSimdArray.
/// Takes a name of a vector to read from, one to write to, and a list of indexes.
/// This statement:
/// ```
///
/// let values = write_complex_to_array!(input, output, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// output.store_complex(input[0], 0),
/// output.store_complex(input[1], 1),
/// output.store_complex(input[2], 2),
/// output.store_complex(input[3], 3),
/// ];
/// ```
///
macro_rules! write_complex_to_array {
($input:ident, $output:ident, { $($idx:literal),* }) => {
$(
Expand All @@ -84,18 +83,18 @@ macro_rules! write_complex_to_array {
/// Write the low half of these indexes of an array of simd vectors to the same indexes of an WasmSimdArray.
/// Takes a name of a vector to read from, one to write to, and a list of indexes.
/// This statement:
/// ```
///
/// let values = write_partial_lo_complex_to_array!(input, output, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// output.store_partial_lo_complex(input[0], 0),
/// output.store_partial_lo_complex(input[1], 1),
/// output.store_partial_lo_complex(input[2], 2),
/// output.store_partial_lo_complex(input[3], 3),
/// ];
/// ```
///
macro_rules! write_partial_lo_complex_to_array {
($input:ident, $output:ident, { $($idx:literal),* }) => {
$(
Expand All @@ -107,18 +106,18 @@ macro_rules! write_partial_lo_complex_to_array {
/// Write these indexes of an array of simd vectors to the same indexes, multiplied by a stride, of an WasmSimdArray.
/// Takes a name of a vector to read from, one to write to, an integer stride, and a list of indexes.
/// This statement:
/// ```
///
/// let values = write_complex_to_array_strided!(input, output, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// output.store_complex(input[0], 0),
/// output.store_complex(input[1], 2),
/// output.store_complex(input[2], 4),
/// output.store_complex(input[3], 6),
/// ];
/// ```
///
macro_rules! write_complex_to_array_strided {
($input:ident, $output:ident, $stride:literal, { $($idx:literal),* }) => {
$(
Expand All @@ -130,18 +129,18 @@ macro_rules! write_complex_to_array_strided {
/// Read these indexes from an WasmSimdArray and build an array of simd vectors.
/// Takes a name of a vector to read from, and a list of indexes to read.
/// This statement:
/// ```
///
/// let values = read_complex_to_array_v128!(input, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// input.load_complex_v128(0),
/// input.load_complex_v128(1),
/// input.load_complex_v128(2),
/// input.load_complex_v128(3),
/// ];
/// ```
///
macro_rules! read_complex_to_array_v128 {
($input:ident, { $($idx:literal),* }) => {
[
Expand All @@ -155,18 +154,18 @@ macro_rules! read_complex_to_array_v128 {
/// Read these indexes from an WasmSimdArray and build an array or partially filled simd vectors.
/// Takes a name of a vector to read from, and a list of indexes to read.
/// This statement:
/// ```
///
/// let values = read_partial1_complex_to_array_v128!(input, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// input.load1_complex_v128(0),
/// input.load1_complex_v128(1),
/// input.load1_complex_v128(2),
/// input.load1_complex_v128(3),
/// ];
/// ```
///
macro_rules! read_partial1_complex_to_array_v128 {
($input:ident, { $($idx:literal),* }) => {
[
Expand All @@ -180,18 +179,18 @@ macro_rules! read_partial1_complex_to_array_v128 {
/// Write these indexes of an array of simd vectors to the same indexes of an WasmSimdArray.
/// Takes a name of a vector to read from, one to write to, and a list of indexes.
/// This statement:
/// ```
///
/// let values = write_complex_to_array_v128!(input, output, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// output.store_complex_v128(input[0], 0),
/// output.store_complex_v128(input[1], 1),
/// output.store_complex_v128(input[2], 2),
/// output.store_complex_v128(input[3], 3),
/// ];
/// ```
///
macro_rules! write_complex_to_array_v128 {
($input:ident, $output:ident, { $($idx:literal),* }) => {
$(
Expand All @@ -203,18 +202,18 @@ macro_rules! write_complex_to_array_v128 {
/// Write these indexes of an array of simd vectors to the same indexes, multiplied by a stride, of an WasmSimdArray.
/// Takes a name of a vector to read from, one to write to, an integer stride, and a list of indexes.
/// This statement:
/// ```
///
/// let values = write_complex_to_array_strided_v128!(input, output, {0, 1, 2, 3});
/// ```
///
/// is equivalent to:
/// ```
///
/// let values = [
/// output.store_complex_v128(input[0], 0),
/// output.store_complex_v128(input[1], 2),
/// output.store_complex_v128(input[2], 4),
/// output.store_complex_v128(input[3], 6),
/// ];
/// ```
///
macro_rules! write_complex_to_array_strided_v128 {
($input:ident, $output:ident, $stride:literal, { $($idx:literal),* }) => {
$(
Expand Down