From af7aa905ac67e4ceb75f6a383145a8a6f99b6e7d Mon Sep 17 00:00:00 2001 From: George Woltman Date: Mon, 7 Apr 2025 21:55:07 +0000 Subject: [PATCH 001/115] Minor optimization. Compute weights while waiting for carries to be ready. --- src/cl/carryfused.cl | 56 +++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index d0623dee..bf905597 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -134,49 +134,58 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // but it's fine either way. if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } -#if OLD_FENCE - + // Tell next line that its carries are ready if (gr < H) { +#if OLD_FENCE // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); write_mem_fence(CLK_GLOBAL_MEM_FENCE); bar(); - if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } - } - if (gr == 0) { return; } - if (me == 0) { do { spin(); } while(!atomic_load((atomic_uint *) &ready[gr - 1])); } - // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); - bar(); - read_mem_fence(CLK_GLOBAL_MEM_FENCE); - - // Clear carry ready flag for next iteration - if (me == 0) ready[gr - 1] = 0; - #else - - if (gr < H) { write_mem_fence(CLK_GLOBAL_MEM_FENCE); if (me % WAVEFRONT == 0) { u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; atomic_store((atomic_uint *) &ready[pos], 1); } +#endif } + + // Line zero will be redone when gr == H if (gr == 0) { return; } + + // Do some work while our carries may not be ready #if HAS_ASM __asm("s_setprio 0"); #endif + + // Calculate inverse weights + T base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + T weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + T weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + u[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; if (me % WAVEFRONT == 0) { do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); } -#if HAS_ASM - __asm("s_setprio 1"); -#endif mem_fence(CLK_GLOBAL_MEM_FENCE); - // Clear carry ready flag for next iteration if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; #endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. // The new carry layout lets the compiler generate global_load_dwordx4 instructions. @@ -210,14 +219,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( bool biglit0 = test(b, 2 * i); #endif wu[i] = carryFinal(wu[i], carry[i], biglit0); - } - - T base = optionalHalve(weights.y); - - for (u32 i = 0; i < NW; ++i) { - T weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); - T weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); - u[i] = U2(wu[i].x, wu[i].y) * U2(weight1, weight2); + u[i] *= U2(wu[i].x, wu[i].y); } bar(); From fa3319274be26f46e84d8a0dcadb0bb85d106888 Mon Sep 17 00:00:00 2001 From: George Woltman Date: Mon, 7 Apr 2025 21:58:38 +0000 Subject: [PATCH 002/115] Fixed bug were default TAIL_KERNELS in C code did mot match default in OpenCL code. --- src/cl/tailsquare.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 078259a3..30105619 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -18,7 +18,7 @@ // 2 = double wide, single kernel // 3 = double wide, two kernels #if !defined(TAIL_KERNELS) -#define TAIL_KERNELS 3 // Default is double-wide tailSquare with two kernels +#define TAIL_KERNELS 2 // Default is double-wide tailSquare with two kernels #endif #define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare #define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels From dac928d999af104dd6402d1d3f6d99a64bf4c961 Mon Sep 17 00:00:00 2001 From: George Woltman Date: Wed, 9 Apr 2025 02:16:03 +0000 Subject: [PATCH 003/115] Eliminated BCAST=1 and UNROLL_W=3 hack. Replaced with a 3 digit FFT specification. --- src/Args.cpp | 2 +- src/FFTConfig.cpp | 16 +++++++++-- src/FFTConfig.h | 17 ++++++++++-- src/Gpu.cpp | 1 - src/cl/base.cl | 32 +++++++++++++++------- src/cl/fft-middle.cl | 2 +- src/cl/fftbase.cl | 63 ++++++++++++++++---------------------------- src/cl/fftheight.cl | 19 ++++++++----- src/cl/fftwidth.cl | 19 ++++++++----- src/tune.cpp | 4 +-- 10 files changed, 103 insertions(+), 72 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 52fdebba..d5a11064 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -295,7 +295,7 @@ void Args::parse(const string& line) { } log(" FFT | BPW | Max exp (M)\n"); for (const FFTShape& shape : FFTShape::multiSpec(s)) { - for (u32 variant = 0; variant < FFTConfig::N_VARIANT; ++variant) { + for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { FFTConfig fft{shape, variant, CARRY_AUTO}; log("%12s | %.2f | %5.1f\n", fft.spec().c_str(), fft.maxBpw(), fft.maxExp() / 1'000'000.0); } diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index acb595e0..d5f7cc48 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -15,6 +15,16 @@ #include #include + + +// BUG, BUG, BUG: We need to redo calculations for maxBPW for FFT variants. There are more variants now. Can we derive a formula for each variation +// as a percentage of the best variant? Ex: if variant 313 is the most accurate, dropping to 213 might be 99.92% of max, 113 might be 99.86% of max, +// 303 might be 99.75% of max, etc. Warning: the percentage change from changing the middle variant will be different for every middle size. +// This hack crudely maps new variant numbers to old variant numbers +#define vhack(x) ((variant_W(x) >= 2 ? 2 : 0) + (variant_M(x) == 1 ? 1 : 0)) + + + using namespace std; struct FftBpw { @@ -170,7 +180,9 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : variant{variant}, carry{carry} { - assert(variant < N_VARIANT); + assert(variant_W() < N_VARIANT_W); + assert(variant_M() < N_VARIANT_M); + assert(variant_H() < N_VARIANT_H); } string FFTConfig::spec() const { @@ -179,7 +191,7 @@ string FFTConfig::spec() const { } double FFTConfig::maxBpw() const { - double b = shape.bpw[variant]; + double b = shape.bpw[vhack(variant)]; return carry == CARRY_32 ? std::min(shape.carry32BPW(), b) : b; } diff --git a/src/FFTConfig.h b/src/FFTConfig.h index f8912c98..55e41534 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -26,7 +26,7 @@ class FFTShape { static tuple getChainLengths(u32 fftSize, u32 exponent, u32 middle); static vector multiSpec(const string& spec); - + u32 width = 0; u32 middle = 0; u32 height = 0; @@ -47,11 +47,24 @@ class FFTShape { bool needsLargeCarry(u32 E) const; }; +static const u32 N_VARIANT_W = 4; +static const u32 N_VARIANT_M = 2; +static const u32 N_VARIANT_H = 4; +static const u32 LAST_VARIANT = (N_VARIANT_W - 1) * 100 + (N_VARIANT_M - 1) * 10 + N_VARIANT_H - 1; +inline u32 variant_WMH(u32 v_W, u32 v_M, u32 v_H) { return v_W * 100 + v_M * 10 + v_H; } +inline u32 variant_W(u32 v) { return v / 100; } +inline u32 variant_M(u32 v) { return v % 100 / 10; } +inline u32 variant_H(u32 v) { return v % 10; } +inline u32 next_variant(u32 v) { u32 new_v; + new_v = v + 1; if (variant_H (new_v) < N_VARIANT_H) return (new_v); + new_v = (v / 10 + 1) * 10; if (variant_M (new_v) < N_VARIANT_M) return (new_v); + new_v = (v / 100 + 1) * 100; return (new_v); +} + enum CARRY_KIND { CARRY_32=0, CARRY_64=1, CARRY_AUTO=2}; struct FFTConfig { public: - static const u32 N_VARIANT = 4; static FFTConfig bestFit(const Args& args, u32 E, const std::string& spec); FFTShape shape{}; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 2b6927fd..0f8d5874 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -221,7 +221,6 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "NO_ASM", "DEBUG", "CARRY64", - "BCAST", "BIGLIT", "NONTEMPORAL", "PAD", diff --git a/src/cl/base.cl b/src/cl/base.cl index 0c707044..ec55308d 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -68,24 +68,36 @@ G_H "group height" == SMALL_HEIGHT / NH #define OLD_FENCE 1 #endif -// Nonteporal reads and writes might be a little bit faster on many GPUs by keeping more reusable data in the caches. +// Nontemporal reads and writes might be a little bit faster on many GPUs by keeping more reusable data in the caches. // However, on those GPUs with large caches there should be a significant speed gain from keeping FFT data in the caches. // Default to the big win when caching is beneficial rather than the tiny gain when non-temporal is better. #if !defined(NONTEMPORAL) #define NONTEMPORAL 0 #endif -#if FFT_VARIANT > 3 -#error FFT_VARIANT must be between 0 and 3 +// FFT variant is in 3 parts. One digit for WIDTH, one digit for MIDDLE, one digit for HEIGHT. +// For WIDTH and HEIGHT there are 4 variants: +// 0 compute one trig, bcast, chainmul previously was :even/:odd BCAST=1 +// 1 read one trig, with old chainmul previously was :0/:1 +// 2 read all trigs, no chainmul previously was :2/:3 +// 3 read all trigs, sin/cos format for more FMA previously was :2/:3 UNROLL_W=3 +// Note smaller numbers above do more F64 and are less accurate, larger numbers have more memory accesses +// For MIDDLE there are two variants: +// 0 chainmul +// 1 lots of computing trigs, very short chainmul for maximum accuracy previously was :1/:3 +#define FFT_VARIANT_W (FFT_VARIANT / 100) +#define FFT_VARIANT_M (FFT_VARIANT % 100 / 10) +#define FFT_VARIANT_H (FFT_VARIANT % 10) +#if FFT_VARIANT_W > 3 +#error FFT_VARIANT_W must be between 0 and 3 +#endif +#if FFT_VARIANT_M > 1 +#error FFT_VARIANT_M must be between 0 and 1 +#endif +#if FFT_VARIANT_H > 3 +#error FFT_VARIANT_H must be between 0 and 3 #endif -#if defined(TRIG_HI) || defined(CLEAN) -#error Use FFT_VARIANT instead of TRIG_HI or CLEAN -#endif - -#define TRIG_HI (FFT_VARIANT & 1) -#define CLEAN (FFT_VARIANT >> 1) - #if !defined(UNROLL_W) #if AMDGPU #define UNROLL_W 0 diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index e10b5764..bf61d555 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -84,7 +84,7 @@ void fft_MIDDLE(T2 *u) { // Keep in sync with TrigBufCache.cpp, see comment there. #define SHARP_MIDDLE 5 -#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && TRIG_HI +#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 1 #define MM_CHAIN 1 #define MM2_CHAIN 2 #endif diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index ac0d5c9e..bda7678b 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -84,7 +84,7 @@ void chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { } -#if BCAST +#if AMDGPU int bcast4(int x) { return __builtin_amdgcn_mov_dpp(x, 0, 0xf, 0xf, false); } int bcast8(int x) { return __builtin_amdgcn_ds_swizzle(x, 0x0018); } @@ -182,13 +182,14 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } } -void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { +void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me, bool chainmul) { #if 0 u32 p = me / f * f; #else u32 p = me & ~(f - 1); #endif +// Compute trigs from scratch every time. This can't possibly be a good idea on any GPUs. #if 0 T2 w = slowTrig_N(ND / n / WG * p, ND / n); T2 base = w; @@ -196,57 +197,37 @@ void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { u[i] = cmul(u[i], w); w = cmul(w, base); } + return; #endif -// Theoretically, maximum accuracy. Uses memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. -#if CLEAN == 1 // Radeon VII loves this case, in fact it is faster than the CLEAN == 0 case. nVidia Titan V hates this case. - - T2 w = trig[p]; - - if (n >= 8) { - u[1] = cmulFancy(u[1], w); - } else { - u[1] = cmul(u[1], w); - } +// This code uses chained complex multiplies which could be faster on GPUs with great DP throughput or poor memory bandwidth or caching. +// This ought to be the least accurate version of Tabmul. In practice this is more accurate (at least when n==8) than reading precomputed +// values from memory. Perhaps chained Fancy muls are the reason (or was resolved when the algorithm to precompute trig values changed). - for (u32 i = 2; i < n; ++i) { - T2 base = trig[(i-1)*WG + p]; - u[i] = cmul(u[i], base); + if (chainmul) { + T2 w = trig[p]; + chainMul (n, u, w, 0); + return; } -// Original CLEAN==1, saves one cmul at the cost of a memory access. I see little use for this case. -#elif 0 - T2 w = trig[p]; - - if (n >= 8) { - u[1] = cmulFancy(u[1], w); - } else { - u[1] = cmul(u[1], w); - } +// Theoretically, maximum accuracy. Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. +// Radeon VII loves this case, it is faster than the chainmul case. nVidia Titan V hates this case. - T2 base = trig[WG + p]; + if (!chainmul) { + T2 w = trig[p]; - if (n >= 8) { - for (u32 i = 2; i < n; ++i) { - u[i] = cmul(u[i], base); - base = cmulFancy(base, w); + if (n >= 8) { + u[1] = cmulFancy(u[1], w); + } else { + u[1] = cmul(u[1], w); } - } else { + for (u32 i = 2; i < n; ++i) { + T2 base = trig[(i-1)*WG + p]; u[i] = cmul(u[i], base); - base = cmul(base, w); } + return; } - -// This code uses chained complex multiplies which could be faster on GPUs with great DP throughput or poor memory bandwidth or caching. -// This ought to be the least accurate version of Tabmul. In practice this is more accurate (at least when n==8) than reading precomputed -// values from memory. Perhaps chained Fancy muls are the reason. -#elif CLEAN == 0 - T2 w = trig[p]; - chainMul (n, u, w, 0); -#else -#error CLEAN must be 0 or 1 -#endif } diff --git a/src/cl/fftheight.cl b/src/cl/fftheight.cl index ed82b79f..7199017c 100644 --- a/src/cl/fftheight.cl +++ b/src/cl/fftheight.cl @@ -20,7 +20,14 @@ void fft_NH(T2 *u) { #endif } -#if BCAST && (HEIGHT <= 1024) +#if FFT_VARIANT_H == 0 + +#if HEIGHT > 1024 +#error FFT_VARIANT_H == 0 only supports HEIGHT <= 1024 +#endif +#if !AMDGPU +#error FFT_VARIANT_H == 0 only supported by AMD GPUs +#endif void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { @@ -61,7 +68,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(); } fft_NH(u); - tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me, FFT_VARIANT_H == 1); shufl(SMALL_HEIGHT / NH, lds, u, NH, s); } fft_NH(u); @@ -78,7 +85,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < WG; s *= NH) { if (s > 1) { bar(WG); } fft_NH(u); - tabMul(WG, trig, u, NH, s, me % WG); + tabMul(WG, trig, u, NH, s, me % WG, FFT_VARIANT_H == 1); shufl2(WG, lds, u, NH, s); } fft_NH(u); @@ -97,7 +104,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Custom code for various SMALL_HEIGHT values -#if SMALL_HEIGHT == 256 && NH == 4 && !BCAST && CLEAN == 1 +#if SMALL_HEIGHT == 256 && NH == 4 && FFT_VARIANT_H == 3 // Custom code for SMALL_HEIGHT=256, NH=4 @@ -127,7 +134,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Finish third tabMul and perform final fft4. finish_tabMul4_fft4(WG, partitioned_lds, trig, preloads, u, 16, me, 1); -#elif SMALL_HEIGHT == 512 && NH == 8 && !BCAST && CLEAN == 1 +#elif SMALL_HEIGHT == 512 && NH == 8 && FFT_VARIANT_H == 3 // Custom code for SMALL_HEIGHT=512, NH=8 @@ -151,7 +158,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Finish second tabMul and perform final fft8. finish_tabMul8_fft8(WG, partitioned_lds, trig, preloads, u, 8, me, 1); -#elif SMALL_HEIGHT == 1024 && NH == 4 && !BCAST && CLEAN == 1 +#elif SMALL_HEIGHT == 1024 && NH == 4 && FFT_VARIANT_H == 3 // Custom code for SMALL_HEIGHT=1024, NH=4 diff --git a/src/cl/fftwidth.cl b/src/cl/fftwidth.cl index 4f34161a..d19a2345 100644 --- a/src/cl/fftwidth.cl +++ b/src/cl/fftwidth.cl @@ -18,7 +18,14 @@ void fft_NW(T2 *u) { #endif } -#if BCAST && (WIDTH <= 1024) +#if FFT_VARIANT_W == 0 + +#if WIDTH > 1024 +#error FFT_VARIANT_W == 0 only supports WIDTH <= 1024 +#endif +#if !AMDGPU +#error FFT_VARIANT_W == 0 only supported by AMD GPUs +#endif void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { u32 me = get_local_id(0); @@ -51,7 +58,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { for (u32 s = 1; s < WIDTH / NW; s *= NW) { if (s > 1) { bar(); } fft_NW(u); - tabMul(WIDTH / NW, trig, u, NW, s, me); + tabMul(WIDTH / NW, trig, u, NW, s, me, FFT_VARIANT_W == 1); shufl( WIDTH / NW, lds, u, NW, s); } fft_NW(u); @@ -74,7 +81,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Custom code for various WIDTH values -#if WIDTH == 256 && NW == 4 && !BCAST && CLEAN == 1 && UNROLL_W >= 3 +#if WIDTH == 256 && NW == 4 && FFT_VARIANT_W == 3 // Custom code for WIDTH=256, NW=4 @@ -104,7 +111,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish third tabMul and perform final fft4. finish_tabMul4_fft4(WG, lds, trig, preloads, u, 16, me, 1); -#elif WIDTH == 512 && NW == 8 && !BCAST && CLEAN == 1 && UNROLL_W >= 3 +#elif WIDTH == 512 && NW == 8 && FFT_VARIANT_W == 3 // Custom code for WIDTH=512, NW=8 @@ -128,7 +135,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish second tabMul and perform final fft8. finish_tabMul8_fft8(WG, lds, trig, preloads, u, 8, me, 0); // We'd rather set save_one_more_mul to 1 -#elif WIDTH == 1024 && NW == 4 && !BCAST && CLEAN == 1 && UNROLL_W >= 3 +#elif WIDTH == 1024 && NW == 4 && FFT_VARIANT_W == 3 // Custom code for WIDTH=1024, NW=4 @@ -164,7 +171,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish fourth tabMul and perform final fft4. finish_tabMul4_fft4(WG, lds, trig, preloads, u, 64, me, 1); -#elif WIDTH == 4096 && NW == 8 && !BCAST && CLEAN == 1 && UNROLL_W >= 3 +#elif WIDTH == 4096 && NW == 8 && FFT_VARIANT_W == 3 // Custom code for WIDTH=4K, NW=8 diff --git a/src/tune.cpp b/src/tune.cpp index 8eb1655f..3bd83815 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -161,7 +161,7 @@ void Tune::ztune() { for (FFTShape shape : configs) { double bpw[4]; double A[4]; - for (u32 variant = 0; variant < FFTConfig::N_VARIANT; ++variant) { + for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant(variant)) { FFTConfig fft{shape, variant, CARRY_AUTO}; std::tie(bpw[variant], A[variant]) = maxBpw(fft); } @@ -275,7 +275,7 @@ void Tune::tune() { // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, 0, CARRY_32}.maxExp()); - for (u32 variant = 0; variant < FFTConfig::N_VARIANT; ++variant) { + for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { vector carryToTest{CARRY_32}; // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { From b1deb06c8fad67be1f00b7036eaf13fc6b2edb26 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 8 Jun 2025 19:52:13 +0000 Subject: [PATCH 004/115] Changed FFT spec from a single-digit code (0,1,2,3) to a 3-digit WMH code. Updated maxBpw tables. Changed -tune to handle new FFT spec code. Added TABMUL_CHAIN option (it did not deserve to be a WMH code because it has little impact on Z). --- src/Args.cpp | 3 + src/FFTConfig.cpp | 79 +++++++++---- src/FFTConfig.h | 15 ++- src/Gpu.cpp | 3 +- src/cl/base.cl | 24 ++-- src/cl/fftbase.cl | 19 +-- src/cl/fftheight.cl | 10 +- src/cl/fftwidth.cl | 10 +- src/fftbpw.h | 267 +++++++++++++++--------------------------- src/tune.cpp | 277 ++++++++++++++++++++++++++++++++++---------- src/tune.h | 4 +- 11 files changed, 420 insertions(+), 291 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index d5a11064..57b66be2 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -189,6 +189,9 @@ named "config.txt" in the prpll run directory. -use PAD= : insert pad bytes to possibly improve memory access patterns. Val is number bytes to pad. -use MIDDLE_IN_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory -use MIDDLE_OUT_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory + -use TABMUL_CHAIN=: Controls how trig values are obtained in WIDTH and HEIGHT when FFT-spec is 1. + 0 = Read one trig value and compute the next 3 or 7. + 1 = All trig values are pre-computed and read from memmory. -use DEBUG : enable asserts in OpenCL kernels (slow, developers) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index d5f7cc48..ae34cfa9 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -15,24 +15,14 @@ #include #include - - -// BUG, BUG, BUG: We need to redo calculations for maxBPW for FFT variants. There are more variants now. Can we derive a formula for each variation -// as a percentage of the best variant? Ex: if variant 313 is the most accurate, dropping to 213 might be 99.92% of max, 113 might be 99.86% of max, -// 303 might be 99.75% of max, etc. Warning: the percentage change from changing the middle variant will be different for every middle size. -// This hack crudely maps new variant numbers to old variant numbers -#define vhack(x) ((variant_W(x) >= 2 ? 2 : 0) + (variant_M(x) == 1 ? 1 : 0)) - - - using namespace std; struct FftBpw { string fft; - array bpw; + array bpw; }; -map> BPW { +map> BPW { #include "fftbpw.h" }; @@ -127,6 +117,11 @@ double FFTShape::carry32BPW() const { // while the 0.5*log2() models the impact of FFT size changes. // We model carry with a Gumbel distrib similar to the one used for ROE, and measure carry with // -use STATS=1. See -carryTune + +//GW: I have no idea why this is needed. Without it, -tune fails on FFT sizes from 256K to 1M +// Perhaps it has something to do with RNDVALdoubleToLong in carryutil +if (18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())) > 19.0) return 19.0; + return 18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())); } @@ -149,21 +144,49 @@ FFTShape::FFTShape(u32 w, u32 m, u32 h) : bpw = FFTShape{h, m, w}.bpw; } else { // Make up some defaults - double d = 0.275 * (log2(size()) - log2(256 * 13 * 1024 * 2)); - bpw = {18.1-d, 18.2-d, 18.2-d, 18.3-d}; - log("BPW info for %s not found, defaults={%.2f, %.2f, %.2f, %.2f}\n", s.c_str(), bpw[0], bpw[1], bpw[2], bpw[3]); + + //double d = 0.275 * (log2(size()) - log2(256 * 13 * 1024 * 2)); + //bpw = {18.1-d, 18.2-d, 18.2-d, 18.3-d}; + //log("BPW info for %s not found, defaults={%.2f, %.2f, %.2f, %.2f}\n", s.c_str(), bpw[0], bpw[1], bpw[2], bpw[3]); + + // Manipulate the shape into something that was likely pre-computed + while (m < 9) { m *= 2; w /= 2; } + while (w >= 4*h) { w /= 2; h *= 2; } + while (w < h || w < 256 || w == 2048) { w *= 2; h /= 2; } + while (h < 256) { h *= 2; m /= 2; } + bpw = FFTShape{w, m, h}.bpw; + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) bpw[j] -= 0.05; // Assume this fft spec is worse than measured fft specs + printf("BPW info for %s not found, defaults={", s.c_str()); + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", bpw[j]); + printf("}\n"); } } } +// Return TRUE for "favored" shapes. That is, those that are most likely to be useful. To save time in generating bpw data, only these favored +// shapes have their bpw data pre-computed. Bpw for non-favored shapes is guessed from the bpw data we do have. Also. -tune will normally only +// time favored shapes. These are the rules for deciding favored shapes: +// WIDTH >= HEIGHT +// WIDTH=4K: HEIGHT>=512, MIDDLE>=9 (2*8 combos) +// WIDTH=1K: MIDDLE>=5 (3*12 combos) +// WIDTH=512: MIDDLE>=4 (2*13 combos) +// WIDTH=256: MIDDLE>=1 (16 combos) +bool FFTShape::isFavoredShape() const { + return width >= height && + ((width == 4096 && height >= 512 && middle >= 9) || + (width == 1024 && middle >= 5) || + (width == 512 && middle >= 4) || + (width == 256 && middle >= 1)); +} + FFTConfig::FFTConfig(const string& spec) { auto v = split(spec, ':'); // assert(v.size() == 1 || v.size() == 3 || v.size() == 4 || v.size() == 5); if (v.size() == 1) { - *this = {FFTShape::multiSpec(spec).front(), 3, CARRY_AUTO}; + *this = {FFTShape::multiSpec(spec).front(), LAST_VARIANT, CARRY_AUTO}; } if (v.size() == 3) { - *this = {FFTShape{v[0], v[1], v[2]}, 3, CARRY_AUTO}; + *this = {FFTShape{v[0], v[1], v[2]}, LAST_VARIANT, CARRY_AUTO}; } else if (v.size() == 4) { *this = {FFTShape{v[0], v[1], v[2]}, parseInt(v[3]), CARRY_AUTO}; } else if (v.size() == 5) { @@ -180,18 +203,30 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : variant{variant}, carry{carry} { - assert(variant_W() < N_VARIANT_W); - assert(variant_M() < N_VARIANT_M); - assert(variant_H() < N_VARIANT_H); + assert(variant_W(variant) < N_VARIANT_W); + assert(variant_M(variant) < N_VARIANT_M); + assert(variant_H(variant) < N_VARIANT_H); } string FFTConfig::spec() const { - string s = shape.spec() + ":" + to_string(variant); + string s = shape.spec() + ":" + to_string(variant_W(variant)) + to_string(variant_M(variant)) + to_string(variant_H(variant)); return carry == CARRY_AUTO ? s : (s + (carry == CARRY_32 ? ":0" : ":1")); } double FFTConfig::maxBpw() const { - double b = shape.bpw[vhack(variant)]; + double b; + // Look up the pre-computed maximum bpw. The lookup table contains data for variants 000, 101, 202, 010, 111, 212. + // For 4K width, the lookup table contains data for variants 100, 101, 202, 110, 111, 212 since BCAST only works for width <= 1024. + if (variant_W(variant) == variant_H(variant) || + (shape.width > 1024 && variant_W(variant) == 1 && variant_H(variant) == 0)) { + b = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; + } + // Interpolate for the maximum bpw. This might could be improved upon. However, I doubt people will use these variants often. + else { + double b1 = shape.bpw[variant_M(variant) * 3 + variant_W(variant)]; + double b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; + b = (b1 + b2) / 2.0; + } return carry == CARRY_32 ? std::min(shape.carry32BPW(), b) : b; } diff --git a/src/FFTConfig.h b/src/FFTConfig.h index 55e41534..dc65cad0 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -10,6 +10,10 @@ #include #include +// We pre-calculate the maximum BPW for a number of fft specs. From these entries we can either look up or interpolate to get the +// maximum BPW for all variants of an FFT spec. The variants for which maximum bpw are precomputed are 000, 101, 202, 010, 111, 212. +#define NUM_BPW_ENTRIES 6 + class Args; // Format 'n' with a K or M suffix if multiple of 1024 or 1024*1024 @@ -30,7 +34,7 @@ class FFTShape { u32 width = 0; u32 middle = 0; u32 height = 0; - array bpw; + array bpw; FFTShape(u32 w = 1, u32 m = 1, u32 h = 1); FFTShape(const string& w, const string& m, const string& h); @@ -45,18 +49,19 @@ class FFTShape { double carry32BPW() const; bool needsLargeCarry(u32 E) const; + bool isFavoredShape() const; }; -static const u32 N_VARIANT_W = 4; +static const u32 N_VARIANT_W = 3; static const u32 N_VARIANT_M = 2; -static const u32 N_VARIANT_H = 4; +static const u32 N_VARIANT_H = 3; static const u32 LAST_VARIANT = (N_VARIANT_W - 1) * 100 + (N_VARIANT_M - 1) * 10 + N_VARIANT_H - 1; inline u32 variant_WMH(u32 v_W, u32 v_M, u32 v_H) { return v_W * 100 + v_M * 10 + v_H; } inline u32 variant_W(u32 v) { return v / 100; } inline u32 variant_M(u32 v) { return v % 100 / 10; } inline u32 variant_H(u32 v) { return v % 10; } -inline u32 next_variant(u32 v) { u32 new_v; - new_v = v + 1; if (variant_H (new_v) < N_VARIANT_H) return (new_v); +inline u32 next_variant(u32 v) { + u32 new_v = v + 1; if (variant_H (new_v) < N_VARIANT_H) return (new_v); new_v = (v / 10 + 1) * 10; if (variant_M (new_v) < N_VARIANT_M) return (new_v); new_v = (v / 100 + 1) * 100; return (new_v); } diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 0f8d5874..d72dde7f 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -227,7 +227,8 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "MIDDLE_IN_LDS_TRANSPOSE", "MIDDLE_OUT_LDS_TRANSPOSE", "TAIL_KERNELS", - "TAIL_TRIGS" + "TAIL_TRIGS", + "TABMUL_CHAIN" }); if (!isValid) { log("Warning: unrecognized -use key '%s'\n", k.c_str()); diff --git a/src/cl/base.cl b/src/cl/base.cl index ec55308d..8190406c 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -76,26 +76,30 @@ G_H "group height" == SMALL_HEIGHT / NH #endif // FFT variant is in 3 parts. One digit for WIDTH, one digit for MIDDLE, one digit for HEIGHT. -// For WIDTH and HEIGHT there are 4 variants: +// For WIDTH and HEIGHT there are 3 variants: // 0 compute one trig, bcast, chainmul previously was :even/:odd BCAST=1 -// 1 read one trig, with old chainmul previously was :0/:1 -// 2 read all trigs, no chainmul previously was :2/:3 -// 3 read all trigs, sin/cos format for more FMA previously was :2/:3 UNROLL_W=3 -// Note smaller numbers above do more F64 and are less accurate, larger numbers have more memory accesses +// 1 if TABMUL_CHAIN, read one trig then chainmul previously was :0/:1 +// if !TABMUL_CHAIN, read all trigs, no chainmul previously was :2/:3 +// 2 read all trigs in sin/cos format for more FMA previously was :2/:3 UNROLL_W=3 +// Note: smaller numbers above do more F64 and are less accurate, larger numbers have more memory accesses and are more accurate // For MIDDLE there are two variants: -// 0 chainmul +// 0 full length chainmul // 1 lots of computing trigs, very short chainmul for maximum accuracy previously was :1/:3 #define FFT_VARIANT_W (FFT_VARIANT / 100) #define FFT_VARIANT_M (FFT_VARIANT % 100 / 10) #define FFT_VARIANT_H (FFT_VARIANT % 10) -#if FFT_VARIANT_W > 3 -#error FFT_VARIANT_W must be between 0 and 3 +#if FFT_VARIANT_W > 2 +#error FFT_VARIANT_W must be between 0 and 2 #endif #if FFT_VARIANT_M > 1 #error FFT_VARIANT_M must be between 0 and 1 #endif -#if FFT_VARIANT_H > 3 -#error FFT_VARIANT_H must be between 0 and 3 +#if FFT_VARIANT_H > 2 +#error FFT_VARIANT_H must be between 0 and 2 +#endif + +#if !defined(TABMUL_CHAIN) +#define TABMUL_CHAIN 0 #endif #if !defined(UNROLL_W) diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index bda7678b..4e5f9584 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -11,7 +11,7 @@ void chainMul4(T2 *u, T2 w) { T2 base = csqTrig(w); u[2] = cmul(u[2], base); - double a = 2 * base.y; + double a = mul2(base.y); base = U2(fma(a, -w.y, w.x), fma(a, w.x, -w.y)); u[3] = cmul(u[3], base); } @@ -52,10 +52,11 @@ void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { #else // This version of chainMul8 minimizes F64 ops even if that increases roundoff error. -// This version is faster on a Radeon 7 with worse roundoff in :0 fft spec. The :2 fft spec is even faster with no roundoff penalty. -// This version is the same speed on a TitanV due to its great F64 throughput. +// This version is faster on a Radeon 7 with worse roundoff. However, new_FFT_width is even faster with better roundoff. +// This version is the same speed on a TitanV probably due to its great F64 throughput. // This version is slower on R7Pro due to a rocm optimizer issue in double-wide single-kernel tailSquare using BCAST. I could not find a work-around. -// Other GPUs? This version might be useful. +// Other GPUs??? This version might be useful. If we decide to make this available, it will need a new width and height fft spec number. +// Consequently, an increase in the BPW table and increase work for -ztune and -tune. void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { u[1] = cmulFancy(u[1], w); @@ -182,7 +183,7 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } } -void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me, bool chainmul) { +void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { #if 0 u32 p = me / f * f; #else @@ -201,10 +202,10 @@ void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me, bool chainmul) { #endif // This code uses chained complex multiplies which could be faster on GPUs with great DP throughput or poor memory bandwidth or caching. -// This ought to be the least accurate version of Tabmul. In practice this is more accurate (at least when n==8) than reading precomputed -// values from memory. Perhaps chained Fancy muls are the reason (or was resolved when the algorithm to precompute trig values changed). +// This ought to be the least accurate version of Tabmul. In practice, this is just as accurate as reading precomputed values from memory. +// Apparently, chained Fancy muls at these short n=4 and n=8 lengths are very accurate. - if (chainmul) { + if (TABMUL_CHAIN) { T2 w = trig[p]; chainMul (n, u, w, 0); return; @@ -213,7 +214,7 @@ void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me, bool chainmul) { // Theoretically, maximum accuracy. Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. // Radeon VII loves this case, it is faster than the chainmul case. nVidia Titan V hates this case. - if (!chainmul) { + if (!TABMUL_CHAIN) { T2 w = trig[p]; if (n >= 8) { diff --git a/src/cl/fftheight.cl b/src/cl/fftheight.cl index 7199017c..2ceb9326 100644 --- a/src/cl/fftheight.cl +++ b/src/cl/fftheight.cl @@ -68,7 +68,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(); } fft_NH(u); - tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me, FFT_VARIANT_H == 1); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); shufl(SMALL_HEIGHT / NH, lds, u, NH, s); } fft_NH(u); @@ -85,7 +85,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < WG; s *= NH) { if (s > 1) { bar(WG); } fft_NH(u); - tabMul(WG, trig, u, NH, s, me % WG, FFT_VARIANT_H == 1); + tabMul(WG, trig, u, NH, s, me % WG); shufl2(WG, lds, u, NH, s); } fft_NH(u); @@ -104,7 +104,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Custom code for various SMALL_HEIGHT values -#if SMALL_HEIGHT == 256 && NH == 4 && FFT_VARIANT_H == 3 +#if SMALL_HEIGHT == 256 && NH == 4 && FFT_VARIANT_H == 2 // Custom code for SMALL_HEIGHT=256, NH=4 @@ -134,7 +134,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Finish third tabMul and perform final fft4. finish_tabMul4_fft4(WG, partitioned_lds, trig, preloads, u, 16, me, 1); -#elif SMALL_HEIGHT == 512 && NH == 8 && FFT_VARIANT_H == 3 +#elif SMALL_HEIGHT == 512 && NH == 8 && FFT_VARIANT_H == 2 // Custom code for SMALL_HEIGHT=512, NH=8 @@ -158,7 +158,7 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { // Finish second tabMul and perform final fft8. finish_tabMul8_fft8(WG, partitioned_lds, trig, preloads, u, 8, me, 1); -#elif SMALL_HEIGHT == 1024 && NH == 4 && FFT_VARIANT_H == 3 +#elif SMALL_HEIGHT == 1024 && NH == 4 && FFT_VARIANT_H == 2 // Custom code for SMALL_HEIGHT=1024, NH=4 diff --git a/src/cl/fftwidth.cl b/src/cl/fftwidth.cl index d19a2345..92cfc47d 100644 --- a/src/cl/fftwidth.cl +++ b/src/cl/fftwidth.cl @@ -58,7 +58,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { for (u32 s = 1; s < WIDTH / NW; s *= NW) { if (s > 1) { bar(); } fft_NW(u); - tabMul(WIDTH / NW, trig, u, NW, s, me, FFT_VARIANT_W == 1); + tabMul(WIDTH / NW, trig, u, NW, s, me); shufl( WIDTH / NW, lds, u, NW, s); } fft_NW(u); @@ -81,7 +81,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Custom code for various WIDTH values -#if WIDTH == 256 && NW == 4 && FFT_VARIANT_W == 3 +#if WIDTH == 256 && NW == 4 && FFT_VARIANT_W == 2 // Custom code for WIDTH=256, NW=4 @@ -111,7 +111,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish third tabMul and perform final fft4. finish_tabMul4_fft4(WG, lds, trig, preloads, u, 16, me, 1); -#elif WIDTH == 512 && NW == 8 && FFT_VARIANT_W == 3 +#elif WIDTH == 512 && NW == 8 && FFT_VARIANT_W == 2 // Custom code for WIDTH=512, NW=8 @@ -135,7 +135,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish second tabMul and perform final fft8. finish_tabMul8_fft8(WG, lds, trig, preloads, u, 8, me, 0); // We'd rather set save_one_more_mul to 1 -#elif WIDTH == 1024 && NW == 4 && FFT_VARIANT_W == 3 +#elif WIDTH == 1024 && NW == 4 && FFT_VARIANT_W == 2 // Custom code for WIDTH=1024, NW=4 @@ -171,7 +171,7 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { // Finish fourth tabMul and perform final fft4. finish_tabMul4_fft4(WG, lds, trig, preloads, u, 64, me, 1); -#elif WIDTH == 4096 && NW == 8 && FFT_VARIANT_W == 3 +#elif WIDTH == 4096 && NW == 8 && FFT_VARIANT_W == 2 // Custom code for WIDTH=4K, NW=8 diff --git a/src/fftbpw.h b/src/fftbpw.h index d009981d..1e8712d9 100644 --- a/src/fftbpw.h +++ b/src/fftbpw.h @@ -1,174 +1,93 @@ -{ "256:2:256", {19.521, 19.521, 19.651, 19.651}}, -{ "256:3:256", {19.362, 19.363, 19.321, 19.392}}, -{ "256:4:256", {19.238, 19.249, 19.342, 19.366}}, -{ "256:2:512", {19.150, 19.150, 19.287, 19.287}}, -{ "512:2:256", {19.150, 19.150, 19.285, 19.285}}, -{ "256:5:256", {19.094, 19.149, 19.207, 19.268}}, -{ "256:6:256", {19.057, 19.114, 19.052, 19.127}}, -{ "256:3:512", {19.052, 19.058, 19.115, 19.127}}, -{ "512:3:256", {19.053, 19.054, 19.098, 19.104}}, -{ "256:7:256", {18.990, 19.047, 18.995, 19.055}}, -{ "1K:2:256", {18.991, 18.991, 19.133, 19.133}}, -{ "256:8:256", {18.910, 18.975, 18.996, 19.118}}, -{ "256:4:512", {18.877, 18.901, 19.011, 19.038}}, -{ "256:2:1K", {18.971, 18.971, 19.135, 19.135}}, -{ "512:4:256", {18.875, 18.899, 19.015, 19.036}}, -{ "512:2:512", {18.796, 18.796, 18.944, 18.944}}, -{ "256:9:256", {18.844, 18.962, 18.852, 18.968}}, -{"256:10:256", {18.777, 18.887, 18.890, 19.025}}, -{ "256:5:512", {18.729, 18.768, 18.867, 18.930}}, -{ "512:5:256", {18.721, 18.758, 18.859, 18.926}}, -{"256:11:256", {18.721, 18.888, 18.814, 18.993}}, -{ "1K:3:256", {18.830, 18.827, 18.849, 18.869}}, -{"256:12:256", {18.729, 18.834, 18.753, 18.842}}, -{ "256:6:512", {18.741, 18.795, 18.797, 18.844}}, -{ "256:3:1K", {18.831, 18.827, 18.855, 18.860}}, -{ "512:6:256", {18.756, 18.777, 18.786, 18.842}}, -{ "512:3:512", {18.716, 18.723, 18.799, 18.812}}, -{"256:13:256", {18.672, 18.773, 18.760, 18.914}}, -{"256:14:256", {18.639, 18.795, 18.669, 18.747}}, -{ "256:7:512", {18.685, 18.730, 18.705, 18.767}}, -{ "512:7:256", {18.670, 18.734, 18.697, 18.768}}, -{"256:15:256", {18.610, 18.753, 18.656, 18.815}}, -{"256:16:256", {18.552, 18.665, 18.656, 18.835}}, -{ "1K:4:256", {18.694, 18.711, 18.816, 18.843}}, -{ "1K:2:512", {18.627, 18.627, 18.765, 18.765}}, -{ "256:8:512", {18.558, 18.620, 18.658, 18.740}}, -{ "256:4:1K", {18.682, 18.696, 18.814, 18.839}}, -{ "512:8:256", {18.570, 18.606, 18.643, 18.725}}, -{ "512:4:512", {18.510, 18.539, 18.647, 18.670}}, -{ "512:2:1K", {18.617, 18.617, 18.756, 18.756}}, -{ "256:9:512", {18.550, 18.637, 18.590, 18.693}}, -{ "512:9:256", {18.554, 18.647, 18.592, 18.672}}, -{ "1K:5:256", {18.549, 18.591, 18.676, 18.735}}, -{"256:10:512", {18.458, 18.523, 18.537, 18.654}}, -{ "256:5:1K", {18.537, 18.590, 18.676, 18.726}}, -{"512:10:256", {18.473, 18.530, 18.550, 18.652}}, -{ "512:5:512", {18.408, 18.377, 18.523, 18.550}}, -{"256:11:512", {18.409, 18.514, 18.481, 18.636}}, -{"512:11:256", {18.402, 18.519, 18.479, 18.638}}, -{ "1K:6:256", {18.515, 18.568, 18.542, 18.560}}, -{ "1K:3:512", {18.520, 18.522, 18.582, 18.595}}, -{"256:12:512", {18.457, 18.543, 18.474, 18.584}}, -{ "256:6:1K", {18.521, 18.568, 18.554, 18.592}}, -{"512:12:256", {18.447, 18.545, 18.479, 18.572}}, -{ "512:6:512", {18.416, 18.435, 18.489, 18.550}}, -{ "512:3:1K", {18.516, 18.515, 18.571, 18.584}}, -{"256:13:512", {18.328, 18.413, 18.453, 18.548}}, -{"512:13:256", {18.326, 18.419, 18.440, 18.536}}, -{ "1K:7:256", {18.448, 18.508, 18.481, 18.531}}, -{"256:14:512", {18.342, 18.507, 18.399, 18.519}}, -{ "256:7:1K", {18.453, 18.501, 18.464, 18.522}}, -{"512:14:256", {18.310, 18.504, 18.379, 18.531}}, -{ "512:7:512", {18.336, 18.375, 18.426, 18.480}}, -{"256:15:512", {18.287, 18.404, 18.377, 18.516}}, -{"512:15:256", {18.292, 18.417, 18.365, 18.522}}, -{"512:16:256", {18.282, 18.367, 18.300, 18.507}}, -{ "1K:8:256", {18.360, 18.438, 18.470, 18.574}}, -{ "1K:4:512", {18.330, 18.373, 18.479, 18.503}}, -{ "1K:2:1K", {18.440, 18.440, 18.581, 18.581}}, -{ "256:8:1K", {18.349, 18.419, 18.470, 18.577}}, -{ "512:8:512", {18.196, 18.253, 18.316, 18.393}}, -{ "512:4:1K", {18.319, 18.337, 18.468, 18.499}}, -{ "4K:2:256", {18.285, 18.285, 18.461, 18.461}}, -{ "1K:9:256", {18.313, 18.425, 18.331, 18.455}}, -{ "256:9:1K", {18.323, 18.412, 18.348, 18.450}}, -{ "512:9:512", {18.208, 18.275, 18.300, 18.404}}, -{ "1K:10:256", {18.241, 18.326, 18.359, 18.493}}, -{ "1K:5:512", {18.189, 18.231, 18.334, 18.372}}, -{ "256:10:1K", {18.230, 18.323, 18.323, 18.486}}, -{"512:10:512", {18.103, 18.128, 18.193, 18.292}}, -{ "512:5:1K", {18.203, 18.218, 18.312, 18.360}}, -{ "1K:11:256", {18.189, 18.332, 18.264, 18.461}}, -{ "256:11:1K", {18.194, 18.326, 18.262, 18.449}}, -{"512:11:512", {18.081, 18.165, 18.179, 18.300}}, -{ "1K:12:256", {18.202, 18.312, 18.218, 18.330}}, -{ "1K:6:512", {18.219, 18.252, 18.271, 18.311}}, -{ "1K:3:1K", {18.287, 18.297, 18.328, 18.334}}, -{ "256:12:1K", {18.206, 18.314, 18.221, 18.315}}, -{"512:12:512", {18.120, 18.177, 18.187, 18.279}}, -{ "512:6:1K", {18.207, 18.255, 18.273, 18.313}}, -{ "4K:3:256", {18.208, 18.206, 18.315, 18.311}}, -{ "1K:13:256", {18.154, 18.232, 18.238, 18.383}}, -{ "256:13:1K", {18.165, 18.228, 18.241, 18.375}}, -{"512:13:512", {18.017, 18.071, 18.121, 18.216}}, -{ "1K:14:256", {18.118, 18.249, 18.140, 18.265}}, -{ "1K:7:512", {18.162, 18.208, 18.186, 18.244}}, -{ "256:14:1K", {18.134, 18.244, 18.150, 18.216}}, -{"512:14:512", {18.027, 18.123, 18.122, 18.229}}, -{ "512:7:1K", {18.151, 18.205, 18.198, 18.252}}, -{ "1K:15:256", {18.099, 18.220, 18.137, 18.274}}, -{ "1K:16:256", {18.025, 18.128, 18.120, 18.286}}, -{ "256:15:1K", {18.104, 18.209, 18.152, 18.277}}, -{"512:15:512", {17.973, 18.074, 18.075, 18.183}}, -{ "1K:8:512", {18.015, 18.090, 18.148, 18.222}}, -{ "1K:4:1K", {18.161, 18.180, 18.249, 18.302}}, -{ "512:8:1K", {18.029, 18.077, 18.136, 18.212}}, -{ "4K:4:256", {18.043, 18.061, 18.189, 18.200}}, -{ "4K:2:512", {18.002, 18.002, 18.118, 18.118}}, -{ "1K:9:512", {18.027, 18.118, 18.065, 18.165}}, -{ "512:9:1K", {18.016, 18.104, 18.083, 18.170}}, -{ "1K:10:512", {17.909, 17.978, 18.022, 18.122}}, -{ "1K:5:1K", {18.011, 18.055, 18.133, 18.213}}, -{ "512:10:1K", {17.905, 17.972, 18.024, 18.112}}, -{ "4K:5:256", {17.914, 17.915, 18.018, 18.080}}, -{ "1K:11:512", {17.873, 17.998, 17.962, 18.128}}, -{ "512:11:1K", {17.871, 17.973, 17.977, 18.114}}, -{ "1K:12:512", {17.929, 18.017, 17.961, 18.060}}, -{ "1K:6:1K", {17.992, 18.040, 18.003, 18.056}}, -{ "512:12:1K", {17.929, 18.013, 17.963, 18.057}}, -{ "4K:6:256", {17.922, 17.954, 17.995, 18.056}}, -{ "4K:3:512", {17.864, 17.868, 18.005, 18.005}}, -{ "1K:13:512", {17.806, 17.883, 17.932, 18.020}}, -{ "512:13:1K", {17.797, 17.883, 17.934, 18.013}}, -{ "1K:14:512", {17.815, 17.958, 17.882, 18.002}}, -{ "1K:7:1K", {17.925, 17.973, 17.947, 18.005}}, -{ "512:14:1K", {17.822, 17.952, 17.868, 18.003}}, -{ "4K:7:256", {17.846, 17.897, 17.931, 18.002}}, -{ "1K:15:512", {17.763, 17.884, 17.860, 17.996}}, -{ "512:15:1K", {17.773, 17.884, 17.863, 17.996}}, -{ "512:16:1K", {17.675, 17.820, 17.787, 17.976}}, -{ "1K:8:1K", {17.806, 17.881, 17.932, 18.043}}, -{ "4K:8:256", {17.715, 17.744, 17.821, 17.897}}, -{ "4K:4:512", {17.678, 17.699, 17.819, 17.840}}, -{ "4K:2:1K", {17.765, 17.765, 17.941, 17.941}}, -{ "1K:9:1K", {17.792, 17.892, 17.831, 17.941}}, -{ "4K:9:256", {17.721, 17.789, 17.810, 17.889}}, -{ "1K:10:1K", {17.699, 17.788, 17.797, 17.951}}, -{ "4K:10:256", {17.599, 17.679, 17.710, 17.786}}, -{ "4K:5:512", {17.560, 17.597, 17.694, 17.716}}, -{ "1K:11:1K", {17.669, 17.776, 17.732, 17.927}}, -{ "4K:11:256", {17.564, 17.678, 17.671, 17.811}}, -{ "1K:12:1K", {17.691, 17.772, 17.713, 17.800}}, -{ "4K:12:256", {17.603, 17.694, 17.687, 17.788}}, -{ "4K:6:512", {17.561, 17.613, 17.696, 17.736}}, -{ "4K:3:1K", {17.674, 17.678, 17.766, 17.778}}, -{ "1K:13:1K", {17.613, 17.688, 17.716, 17.836}}, -{ "4K:13:256", {17.515, 17.583, 17.624, 17.685}}, -{ "1K:14:1K", {17.596, 17.716, 17.626, 17.741}}, -{ "4K:14:256", {17.505, 17.636, 17.579, 17.713}}, -{ "4K:7:512", {17.528, 17.547, 17.631, 17.683}}, -{ "1K:15:1K", {17.566, 17.672, 17.624, 17.747}}, -{ "1K:16:1K", {17.472, 17.617, 17.573, 17.755}}, -{ "4K:15:256", {17.478, 17.566, 17.569, 17.693}}, -{ "4K:16:256", {17.450, 17.538, 17.519, 17.652}}, -{ "4K:8:512", {17.385, 17.424, 17.497, 17.555}}, -{ "4K:4:1K", {17.494, 17.526, 17.626, 17.646}}, -{ "4K:9:512", {17.400, 17.465, 17.500, 17.591}}, -{ "4K:10:512", {17.263, 17.350, 17.374, 17.473}}, -{ "4K:5:1K", {17.365, 17.398, 17.499, 17.534}}, -{ "4K:11:512", {17.245, 17.323, 17.366, 17.473}}, -{ "4K:12:512", {17.279, 17.341, 17.400, 17.488}}, -{ "4K:6:1K", {17.369, 17.416, 17.466, 17.519}}, -{ "4K:13:512", {17.161, 17.236, 17.287, 17.352}}, -{ "4K:14:512", {17.180, 17.264, 17.296, 17.431}}, -{ "4K:7:1K", {17.302, 17.355, 17.391, 17.445}}, -{ "4K:15:512", {17.141, 17.228, 17.255, 17.376}}, -{ "4K:8:1K", {17.174, 17.178, 17.292, 17.360}}, -{ "4K:9:1K", {17.185, 17.259, 17.280, 17.361}}, -{ "4K:10:1K", {17.091, 17.150, 17.185, 17.227}}, -{ "4K:11:1K", {17.070, 17.144, 17.154, 17.260}}, -{ "4K:12:1K", {17.099, 17.158, 17.180, 17.250}}, -{ "4K:13:1K", {16.988, 17.039, 17.107, 17.156}}, -{ "4K:14:1K", {17.023, 17.107, 17.094, 17.205}}, -{ "4K:15:1K", {16.949, 17.050, 17.072, 17.164}}, +{ "256:2:256", {19.204, 19.547, 19.636, 19.204, 19.547, 19.636}}, +{ "256:3:256", {19.106, 19.386, 19.369, 19.106, 19.399, 19.361}}, +{ "256:4:256", {18.928, 19.236, 19.322, 18.954, 19.272, 19.367}}, +{ "256:5:256", {19.093, 19.094, 19.242, 19.147, 19.142, 19.314}}, +{ "256:6:256", {18.805, 19.065, 19.005, 18.854, 19.134, 19.101}}, +{ "256:7:256", {18.688, 18.995, 18.971, 18.763, 19.068, 19.071}}, +{ "256:8:256", {18.600, 18.909, 18.964, 18.679, 19.018, 19.120}}, +{ "512:4:256", {18.729, 18.913, 19.062, 18.770, 18.963, 19.120}}, +{ "256:9:256", {18.634, 18.871, 18.802, 18.688, 18.981, 18.961}}, +{"256:10:256", {18.748, 18.770, 18.906, 18.895, 18.896, 19.051}}, +{ "512:5:256", {18.751, 18.766, 18.891, 18.812, 18.832, 18.966}}, +{"256:11:256", {18.523, 18.782, 18.791, 18.594, 18.910, 18.946}}, +{"256:12:256", {18.558, 18.749, 18.669, 18.612, 18.880, 18.842}}, +{ "512:6:256", {18.641, 18.748, 18.838, 18.686, 18.809, 18.949}}, +{"256:13:256", {18.423, 18.693, 18.794, 18.497, 18.820, 18.938}}, +{"256:14:256", {18.450, 18.671, 18.639, 18.519, 18.808, 18.785}}, +{ "512:7:256", {18.547, 18.671, 18.782, 18.629, 18.738, 18.895}}, +{"256:15:256", {18.666, 18.652, 18.628, 18.798, 18.784, 18.791}}, +{"256:16:256", {18.277, 18.568, 18.649, 18.425, 18.741, 18.849}}, +{ "512:8:256", {18.455, 18.592, 18.682, 18.508, 18.673, 18.835}}, +{ "512:4:512", {18.565, 18.599, 18.642, 18.594, 18.629, 18.716}}, +{ "512:9:256", {18.491, 18.579, 18.648, 18.546, 18.661, 18.792}}, +{ "1K:5:256", {18.553, 18.547, 18.713, 18.622, 18.601, 18.762}}, +{"512:10:256", {18.461, 18.479, 18.570, 18.561, 18.572, 18.697}}, +{ "512:5:512", {18.450, 18.473, 18.513, 18.499, 18.513, 18.583}}, +{"512:11:256", {18.335, 18.481, 18.586, 18.444, 18.579, 18.744}}, +{ "1K:6:256", {18.294, 18.528, 18.476, 18.335, 18.596, 18.562}}, +{"512:12:256", {18.381, 18.462, 18.537, 18.448, 18.560, 18.678}}, +{ "512:6:512", {18.440, 18.462, 18.585, 18.479, 18.504, 18.644}}, +{"512:13:256", {18.247, 18.401, 18.482, 18.348, 18.497, 18.645}}, +{ "1K:7:256", {18.171, 18.455, 18.442, 18.238, 18.531, 18.521}}, +{"512:14:256", {18.266, 18.376, 18.467, 18.374, 18.490, 18.637}}, +{ "512:7:512", {18.372, 18.372, 18.462, 18.428, 18.432, 18.567}}, +{"512:15:256", {18.351, 18.355, 18.460, 18.453, 18.465, 18.619}}, +{ "1K:8:256", {18.093, 18.359, 18.435, 18.156, 18.461, 18.570}}, +{"512:16:256", {18.108, 18.260, 18.298, 18.223, 18.420, 18.570}}, +{ "512:8:512", {18.256, 18.280, 18.314, 18.319, 18.369, 18.444}}, +{ "1K:9:256", {18.123, 18.328, 18.245, 18.171, 18.443, 18.426}}, +{ "512:9:512", {18.243, 18.262, 18.343, 18.315, 18.342, 18.493}}, +{ "1K:10:256", {18.233, 18.228, 18.370, 18.357, 18.343, 18.516}}, +{ "1K:5:512", {18.237, 18.232, 18.385, 18.281, 18.272, 18.479}}, +{"512:10:512", {18.142, 18.160, 18.174, 18.226, 18.243, 18.314}}, +{ "1K:11:256", {18.001, 18.234, 18.243, 18.084, 18.362, 18.411}}, +{"512:11:512", {18.156, 18.170, 18.196, 18.242, 18.251, 18.356}}, +{ "1K:12:256", {18.052, 18.225, 18.151, 18.090, 18.332, 18.291}}, +{ "1K:6:512", {18.125, 18.227, 18.293, 18.160, 18.283, 18.401}}, +{"512:12:512", {18.141, 18.163, 18.245, 18.217, 18.237, 18.409}}, +{ "1K:13:256", {17.903, 18.155, 18.249, 17.980, 18.261, 18.380}}, +{"512:13:512", {18.097, 18.104, 18.135, 18.168, 18.171, 18.239}}, +{ "1K:14:256", {17.929, 18.143, 18.103, 18.005, 18.268, 18.244}}, +{ "1K:7:512", {18.043, 18.150, 18.237, 18.100, 18.208, 18.347}}, +{"512:14:512", {18.094, 18.091, 18.154, 18.171, 18.173, 18.311}}, +{ "1K:15:256", {18.131, 18.127, 18.115, 18.260, 18.236, 18.238}}, +{"512:15:512", {18.049, 18.061, 18.112, 18.130, 18.136, 18.236}}, +{ "1K:16:256", {17.752, 18.019, 18.128, 17.892, 18.183, 18.326}}, +{ "1K:8:512", {17.930, 18.069, 18.190, 17.982, 18.137, 18.333}}, +{"512:16:512", {17.971, 17.997, 17.976, 18.072, 18.100, 18.170}}, +{ "1K:9:512", {17.944, 18.049, 18.124, 18.012, 18.122, 18.247}}, +{ "1K:10:512", {17.940, 17.939, 18.044, 18.044, 18.043, 18.195}}, +{ "1K:5:1K", {18.033, 18.016, 18.180, 18.101, 18.072, 18.237}}, +{ "1K:11:512", {17.829, 17.956, 18.078, 17.922, 18.044, 18.219}}, +{ "1K:12:512", {17.829, 17.944, 18.002, 17.912, 18.034, 18.139}}, +{ "1K:6:1K", {17.741, 18.003, 17.931, 17.794, 18.065, 18.051}}, +{ "1K:13:512", {17.744, 17.869, 18.010, 17.822, 17.959, 18.133}}, +{ "1K:14:512", {17.748, 17.854, 17.933, 17.848, 17.968, 18.100}}, +{ "1K:7:1K", {17.647, 17.914, 17.865, 17.714, 17.995, 17.994}}, +{ "1K:15:512", {17.828, 17.824, 17.943, 17.926, 17.934, 18.100}}, +{ "1K:16:512", {17.613, 17.744, 17.813, 17.711, 17.875, 18.073}}, +{ "1K:8:1K", {17.572, 17.810, 17.914, 17.628, 17.923, 18.046}}, +{ "1K:9:1K", {17.600, 17.807, 17.734, 17.636, 17.908, 17.892}}, +{ "1K:10:1K", {17.709, 17.681, 17.841, 17.823, 17.795, 17.981}}, +{ "1K:11:1K", {17.458, 17.682, 17.712, 17.560, 17.811, 17.891}}, +{ "1K:12:1K", {17.474, 17.698, 17.635, 17.560, 17.807, 17.752}}, +{ "1K:13:1K", {17.388, 17.619, 17.706, 17.465, 17.716, 17.853}}, +{ "1K:14:1K", {17.417, 17.627, 17.582, 17.491, 17.734, 17.715}}, +{ "1K:15:1K", {17.612, 17.592, 17.582, 17.720, 17.697, 17.722}}, +{ "1K:16:1K", {17.220, 17.471, 17.615, 17.369, 17.641, 17.778}}, +{ "4K:9:512", {17.457, 17.468, 17.550, 17.519, 17.531, 17.648}}, +{ "4K:10:512", {17.336, 17.336, 17.364, 17.416, 17.433, 17.507}}, +{ "4K:11:512", {17.351, 17.356, 17.393, 17.437, 17.440, 17.548}}, +{ "4K:12:512", {17.351, 17.362, 17.447, 17.420, 17.437, 17.576}}, +{ "4K:13:512", {17.278, 17.271, 17.324, 17.349, 17.351, 17.441}}, +{ "4K:14:512", {17.267, 17.270, 17.342, 17.359, 17.359, 17.503}}, +{ "4K:15:512", {17.238, 17.239, 17.295, 17.305, 17.315, 17.432}}, +{ "4K:16:512", {17.149, 17.163, 17.143, 17.251, 17.271, 17.352}}, +{ "4K:9:1K", {17.130, 17.225, 17.346, 17.189, 17.291, 17.485}}, +{ "4K:10:1K", {17.110, 17.111, 17.209, 17.199, 17.188, 17.351}}, +{ "4K:11:1K", {16.993, 17.108, 17.214, 17.084, 17.196, 17.405}}, +{ "4K:12:1K", {17.006, 17.123, 17.219, 17.101, 17.201, 17.370}}, +{ "4K:13:1K", {16.932, 17.045, 17.154, 17.002, 17.116, 17.298}}, +{ "4K:14:1K", {16.942, 17.055, 17.160, 17.027, 17.127, 17.306}}, +{ "4K:15:1K", {17.021, 17.007, 17.137, 17.104, 17.087, 17.282}}, +{ "4K:16:1K", {16.744, 16.887, 16.966, 16.921, 17.048, 17.208}}, diff --git a/src/tune.cpp b/src/tune.cpp index 3bd83815..2fb622e2 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -101,74 +101,136 @@ string formatConfigResults(const vector& results) { return s; } -pair linearFit(double* z, double STEP) { - double A = (2 * (z[3] - z[0]) + (z[2] - z[1])) / (10 * STEP); - double B = (z[0] + z[1] + z[2] + z[3]) / 4; - return {A, B}; -} - } // namespace -pair Tune::maxBpw(FFTConfig fft) { - const double STEP = 0.06; +double Tune::maxBpw(FFTConfig fft) { - double oldBpw = fft.maxBpw(); - double bpw = oldBpw; +// double bpw = oldBpw; const double TARGET = 28; - - double z[4]{}; - - z[1] = zForBpw(bpw - STEP, fft); - if (z[1] < TARGET) { - log("step down from %f\n", bpw); - bpw -= 2*STEP; - z[2] = z[1]; - z[1] = zForBpw(bpw - STEP, fft); - } else { - z[2] = zForBpw(bpw + STEP, fft); - if (z[2] > TARGET) { - log("step up from %f\n", bpw); - bpw += 2*STEP; - z[1] = z[2]; - z[2] = zForBpw(bpw + STEP, fft); - } + const u32 sample_size = 5; + + // Estimate how much bpw needs to change to increase/decrease Z by 1. + // This doesn't need to be a very accurate estimate. + // This estimate comes from analyzing a 4M FFT and a 7.5M FFT. + // The 4M FFT needed a .015 step, the 7.5M FFT needed a .012 step. + double bpw_step = .015 + (log2(fft.size()) - log2(4.0*1024*1024)) / (log2(7.5*1024*1024) - log2(4.0*1024*1024)) * (.012 - .015); + + // Pick a bpw that might be close to Z=34, it is best to err on the high side of Z=34 + double bpw1 = fft.maxBpw() - 9 * bpw_step; // Old bpw gave Z=28, we want Z=34 (or more) + +// The code below was used when building the maxBpw table from scratch +// u32 non_best_width = N_VARIANT_W - 1 - variant_W(fft.variant); // Number of notches below best-Z width variant +// u32 non_best_middle = N_VARIANT_M - 1 - variant_M(fft.variant); // Number of notches below best-Z middle variant +// double bpw1 = 18.3 - 0.275 * (log2(fft.size()) - log2(256 * 13 * 1024 * 2)) - // Default max bpw from an old gpuowl version +// 9 * bpw_step - // Default above should give Z=28, we want Z=34 (or more) +// (.08/.012 * bpw_step) * non_best_width - // 7.5M FFT has ~.08 bpw difference for each width variant below best variant +// (.06 + .04 * (fft.shape.middle - 4) / 11) * non_best_middle; // Assume .1 bpw difference MIDDLE=15 and .06 for MIDDLE=4 +//Above fails for FFTs below 512K. Perhaps we should ditch the above and read from the existing fftbpw.h data to get our starting guess. +//if (fft.size() < 512000) bpw1 = 19, bpw_step = .02; + + // Fine tune our estimate for Z=34 + double z1 = zForBpw(bpw1, fft, 1); +printf ("Guess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); + while (z1 < 31.0 || z1 > 37.0) { + double prev_bpw1 = bpw1; + double prev_z1 = z1; + bpw1 = bpw1 + (z1 - 34) * bpw_step; + z1 = zForBpw(bpw1, fft, 1); +printf ("Reguess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); + bpw_step = - (bpw1 - prev_bpw1) / (z1 - prev_z1); + if (bpw_step < 0.005) bpw_step = 0.005; + if (bpw_step > 0.025) bpw_step = 0.025; } - z[0] = zForBpw(bpw - 2 * STEP, fft); - z[3] = zForBpw(bpw + 2 * STEP, fft); - auto [A, B] = linearFit(z, STEP); + // Get more samples for this bpw -- average in the sample we already have + z1 = (z1 + (sample_size - 1) * zForBpw(bpw1, fft, sample_size - 1)) / sample_size; + + // Pick a bpw somewhere near Z=22 then fine tune the guess + double bpw2 = bpw1 + (z1 - 22) * bpw_step; + double z2 = zForBpw(bpw2, fft, 1); +printf ("Guess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bpw2, z2); + while (z2 < 20.0 || z2 > 25.0) { + double prev_bpw2 = bpw2; + double prev_z2 = z2; +// bool error_recovery = (z2 <= 0.0); +// if (error_recovery) bpw2 -= bpw_step; else + bpw2 = bpw2 + (z2 - 21) * bpw_step; + z2 = zForBpw(bpw2, fft, 1); +printf ("Reguess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bpw2, z2); +// if (error_recovery) { if (z2 >= 20.0) break; else continue; } + bpw_step = - (bpw2 - prev_bpw2) / (z2 - prev_z2); + if (bpw_step < 0.005) bpw_step = 0.005; + if (bpw_step > 0.025) bpw_step = 0.025; + } - double x = bpw + (TARGET - B) / A; + // Get more samples for this bpw -- average in the sample we already have + z2 = (z2 + (sample_size - 1) * zForBpw(bpw2, fft, sample_size - 1)) / sample_size; - log("%s %.3f -> %.3f | %.2f %.2f %.2f %.2f | %.0f %.1f\n", - fft.spec().c_str(), bpw, x, z[0], z[1], z[2], z[3], -A, B); - return {x, -A}; + // Interpolate for the TARGET Z value + return bpw2 + (bpw1 - bpw2) * (TARGET - z2) / (z1 - z2); } -double Tune::zForBpw(double bpw, FFTConfig fft) { - u32 exponent = primes.nearestPrime(fft.size() * bpw + 0.5); - auto [ok, res, roeSq, roeMul] = Gpu::make(q, exponent, shared, fft, {}, false)->measureROE(true); - double z = roeSq.z(); - if (!ok) { log("Error at bpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); } - return z; +double Tune::zForBpw(double bpw, FFTConfig fft, u32 count) { + u32 exponent = (count == 1) ? primes.prevPrime(fft.size() * bpw) : primes.nextPrime(fft.size() * bpw); + double total_z = 0.0; + for (u32 i = 0; i < count; i++, exponent = primes.nextPrime (exponent + 1)) { + auto [ok, res, roeSq, roeMul] = Gpu::make(q, exponent, shared, fft, {}, false)->measureROE(true); + double z = roeSq.z(); + total_z += z; +log("Zforbpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); + if (!ok) { log("Error at bpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); continue; } + } +//printf ("Out zForBpw %s %.2f avg %.2f\n", fft.spec().c_str(), bpw, total_z / count); + return total_z / count; } void Tune::ztune() { File ztune = File::openAppend("ztune.txt"); ztune.printf("\n// %s\n\n", shortTimeStr().c_str()); + + // Study a specific shape and variant + if (0) { + FFTShape shape = FFTShape(512, 15, 512); + u32 variant = 202; + u32 sample_size = 5; + FFTConfig fft{shape, variant, CARRY_AUTO}; + for (double bpw = 18.18; bpw < 18.305; bpw += 0.02) { + double z = zForBpw(bpw, fft, sample_size); + log ("Avg zForBpw %s %.2f %.2f\n", fft.spec().c_str(), bpw, z); + } + } + + // Generate a decent-sized sample that correlates bpw and Z in a range that is close to the target Z value of 28. + // For no particularly good reason, I strive to find the bpw for Z values near 35 and 21. + // Over this narrow Z range, linear curve fit should work well. The Z data is noisy, so more samples is better. + auto configs = FFTShape::multiSpec(shared.args->fftSpec); for (FFTShape shape : configs) { - double bpw[4]; - double A[4]; - for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant(variant)) { - FFTConfig fft{shape, variant, CARRY_AUTO}; - std::tie(bpw[variant], A[variant]) = maxBpw(fft); + + // 4K widths store data on variants 100, 101, 202, 110, 111, 212 + u32 bpw_variants[NUM_BPW_ENTRIES] = {000, 101, 202, 10, 111, 212}; + if (shape.width > 1024) bpw_variants[0] = 100, bpw_variants[3] = 110; + + // Copy the existing bpw array (in case we're replacing only some of the entries) + array bpw; + bpw = shape.bpw; + + // Not all shapes have their maximum bpw per-computed. But one can work on a non-favored shape by specifying it on the command line. + if (configs.size() > 1) { + if (!shape.isFavoredShape()) { log ("Skipping %s\n", shape.spec().c_str()); continue; } + } + + // Test specific variants needed for the maximum bpw table in fftbpw.h + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) { + FFTConfig fft{shape, bpw_variants[j], CARRY_AUTO}; + bpw[j] = maxBpw(fft); } string s = "\""s + shape.spec() + "\""; - ztune.printf("{%12s, {%.3f, %.3f, %.3f, %.3f}},\n", s.c_str(), bpw[0], bpw[1], bpw[2], bpw[3]); - // ztune.printf("{%12s, {%.3f, %.3f, %.3f, %.3f}, {%.0f, %.0f, %.0f, %.0f}},\n", - // s.c_str(), bpw[0], bpw[1], bpw[2], bpw[3], A[0], A[1], A[2], A[3]); +// ztune.printf("{%12s, {%.3f, %.3f, %.3f, %.3f, %.3f, %.3f}},\n", s.c_str(), bpw[0], bpw[1], bpw[2], bpw[3], bpw[4], bpw[5]); + ztune.printf("{%12s, {", s.c_str()); + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) ztune.printf("%s%.3f", j ? ", " : "", bpw[j]); + ztune.printf("}},\n"); } } @@ -178,7 +240,7 @@ void Tune::carryTune() { shared.args->flags["STATS"] = "1"; u32 prevSize = 0; for (FFTShape shape : FFTShape::multiSpec(shared.args->fftSpec)) { - FFTConfig fft{shape, 3, CARRY_AUTO}; + FFTConfig fft{shape, LAST_VARIANT, CARRY_AUTO}; if (prevSize == fft.size()) { continue; } prevSize = fft.size(); @@ -267,15 +329,118 @@ void Tune::tune() { Args *args = shared.args; string fftSpec = args->fftSpec; +//GW: detail all the configs we should auto-time first + + // Flags that prune the amount of shapes and variants to time. + // These should be computed automatically and saved in the tune.txt or config.txt file. + // Tune.txt file should have a version number. + + // A command line option to run more combinations (higher number skips more combos) + int skip_some_WH_variants = 1; // 0 = skip nothing, 1 = skip slower widths/heights unless they have better Z, 2 = only run fastest widths/heights + + // The width = height = 512 FFT shape is so good, we probably don't need to time the width = 1024, height = 256 shape. + bool skip_1K_256 = 1; + + // There are some variands only AMD GPUs can execute + bool AMDGPU = isAmdGpu(q->context->deviceId()); + +// make command line args for this? +skip_some_WH_variants = 2; +skip_1K_256 = 0; + +//GW: Suggest tuning with TAIL_KERNELS=2 even if production runs use TAIL_KERNELS=3 + + // For each width, time the 001, 101, and 201 variants to find the fastest width variant. + // In an ideal world we'd use the -time feature and look at the kCarryFused timing. Then we'd save this info in config.txt or tune.txt. + map fastest_width_variants; + + // For each height, time the 100, 101, and 102 variants to find the fastest height variant. + // In an ideal world we'd use the -time feature and look at the tailSquare timing. Then we'd save this info in config.txt or tune.txt. + map fastest_height_variants; + vector results = TuneEntry::readTuneFile(*args); + vector shapes = FFTShape::multiSpec(args->fftSpec); - for (const FFTShape& shape : FFTShape::multiSpec(args->fftSpec)) { - double minCost = -1; + // Loop through all possible FFT shapes + for (const FFTShape& shape : shapes) { // Time an exponent that's good for all variants and carry-config. - u32 exponent = primes.prevPrime(FFTConfig{shape, 0, CARRY_32}.maxExp()); + u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); + // Loop through all possible variants for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { + + // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. + if (variant_W(variant) == 0) { + if (!AMDGPU) continue; + if (shape.width > 1024) continue; + } + + // Only AMD GPUs support variant zero (BCAST) and only if height <= 1024. + if (variant_H(variant) == 0) { + if (!AMDGPU) continue; + if (shape.height > 1024) continue; + } + + // If only one shape was specified on the command line, time it. This lets the user time any shape, including non-favored ones. + if (shapes.size() > 1) { + + // Skip less-favored shapes + if (!shape.isFavoredShape()) continue; + + // Skip width = 1K, height = 256 + if (shape.width == 1024 && shape.height == 256 && skip_1K_256) continue; + + // Skip variants where width or height are not using the fastest variant. + // NOTE: We ought to offer a tune=option where we also test more accurate variants to extend the FFT's max exponent. + if (skip_some_WH_variants) { + u32 fastest_width = 1; + if (auto it = fastest_width_variants.find(shape.width); it != fastest_width_variants.end()) { + fastest_width = it->second; + } else { + FFTShape test = FFTShape(shape.width, 12, 256); + double cost, min_cost = -1.0; + for (u32 w = 0; w < N_VARIANT_W; w++) { + if (w == 0 && !AMDGPU) continue; + if (w == 0 && test.width > 1024) continue; + FFTConfig fft{test, variant_WMH (w, 0, 1), CARRY_32}; + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + log("Fast width search %6.1f %12s\n", cost, fft.spec().c_str()); + if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_width = w; } + } + fastest_width_variants[shape.width] = fastest_width; + } + if (skip_some_WH_variants == 2 && variant_W(variant) != fastest_width) continue; + if (skip_some_WH_variants == 1 && + FFTConfig{shape, variant, CARRY_32}.maxBpw() < + FFTConfig{shape, variant_WMH (fastest_width, variant_M(variant), variant_H(variant)), CARRY_32}.maxBpw()) continue; + } + if (skip_some_WH_variants) { + u32 fastest_height = 1; + if (auto it = fastest_height_variants.find(shape.height); it != fastest_height_variants.end()) { + fastest_height = it->second; + } else { + FFTShape test = FFTShape(shape.height, 12, shape.height); + double cost, min_cost = -1.0; + for (u32 h = 0; h < N_VARIANT_H; h++) { + if (h == 0 && !AMDGPU) continue; + if (h == 0 && test.height > 1024) continue; + FFTConfig fft{test, variant_WMH (1, 0, h), CARRY_32}; + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + log("Fast height search %6.1f %12s\n", cost, fft.spec().c_str()); + if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_height = h; } + } + fastest_height_variants[shape.height] = fastest_height; + } + if (skip_some_WH_variants == 2 && variant_H(variant) != fastest_height) continue; + if (skip_some_WH_variants == 1 && + FFTConfig{shape, variant, CARRY_32}.maxBpw() < + FFTConfig{shape, variant_WMH (variant_W(variant), variant_M(variant), fastest_height), CARRY_32}.maxBpw()) continue; + } + } + +//GW: If variant is specified on command line, time it (and only it)?? Or an option to only time one variant number?? + vector carryToTest{CARRY_32}; // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { @@ -285,14 +450,10 @@ void Tune::tune() { for (auto carry : carryToTest) { FFTConfig fft{shape, variant, carry}; - if (minCost > 0 && !TuneEntry{minCost, fft}.willUpdate(results)) { - // log("skipped %s %9u\n", fft.spec().c_str(), fft.maxExp()); - continue; - } + // Skip middle = 1, CARRY_32 if maximum exponent would be the same as middle = 0, CARRY_32 + if (variant_M(variant) > 0 && carry == CARRY_32 && fft.maxExp() <= FFTConfig{shape, variant - 10, CARRY_32}.maxExp()) continue; double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - if (minCost <= 0) { minCost = cost; } - bool isUseful = TuneEntry{cost, fft}.update(results); log("%c %6.1f %12s %9u\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); } diff --git a/src/tune.h b/src/tune.h index 5428d991..de50bf52 100644 --- a/src/tune.h +++ b/src/tune.h @@ -22,8 +22,8 @@ class Tune { GpuCommon shared; Primes primes; - std::pair maxBpw(FFTConfig fft); - double zForBpw(double bpw, FFTConfig fft); + double maxBpw(FFTConfig fft); + double zForBpw(double bpw, FFTConfig fft, u32); public: Tune(Queue *q, GpuCommon shared) : q{q}, shared{shared} {} From 47ef597f22f6ffcccfcfbfccaa355dc8605fc7d2 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 10 Jun 2025 08:59:46 +0000 Subject: [PATCH 005/115] Added -tune code to time all important -use options! --- src/TrigBufCache.cpp | 6 +- src/TrigBufCache.h | 2 +- src/tune.cpp | 246 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 242 insertions(+), 12 deletions(-) diff --git a/src/TrigBufCache.cpp b/src/TrigBufCache.cpp index c4194da0..7646a724 100644 --- a/src/TrigBufCache.cpp +++ b/src/TrigBufCache.cpp @@ -282,7 +282,7 @@ TrigBufCache::~TrigBufCache() = default; TrigPtr TrigBufCache::smallTrig(u32 W, u32 nW) { lock_guard lock{mut}; auto& m = small; - decay_t::key_type key{W, nW, 0, 0}; + decay_t::key_type key{W, nW, 0, 0, 0, 0}; TrigPtr p{}; auto it = m.find(key); @@ -300,9 +300,9 @@ TrigPtr TrigBufCache::smallTrigCombo(u32 width, u32 middle, u32 W, u32 nW, u32 v lock_guard lock{mut}; auto& m = small; - decay_t::key_type key1{W, nW, width, middle}; + decay_t::key_type key1{W, nW, width, middle, tail_single_wide, tail_trigs}; // We write the "combo" under two keys, so it can also be retrieved as non-combo by smallTrig() - decay_t::key_type key2{W, nW, 0, 0}; + decay_t::key_type key2{W, nW, 0, 0, 0, 0}; TrigPtr p{}; auto it = m.find(key1); diff --git a/src/TrigBufCache.h b/src/TrigBufCache.h index 4767da13..04ace546 100644 --- a/src/TrigBufCache.h +++ b/src/TrigBufCache.h @@ -27,7 +27,7 @@ class TrigBufCache { const Context* context; std::mutex mut; - std::map, TrigPtr::weak_type> small; + std::map, TrigPtr::weak_type> small; std::map, TrigPtr::weak_type> middle; // The shared-pointers below keep the most recent set of buffers alive even without any Gpu instance diff --git a/src/tune.cpp b/src/tune.cpp index 2fb622e2..b4fe9df1 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -327,9 +327,243 @@ void Tune::ctune() { void Tune::tune() { Args *args = shared.args; - string fftSpec = args->fftSpec; + vector shapes = FFTShape::multiSpec(args->fftSpec); + + // There are some options and variants that are different based on GPU manufacturer + bool AMDGPU = isAmdGpu(q->context->deviceId()); + + // Look for best settings of various options + + if (1) { + u32 variant = 101; +//GW: if fft spec on the command line specifies a variant then we should use that variant (I get some interesting results with 000 vs 101 vs 201 vs 202 likely due to rocm optimizer) + + // Find best FAST_BARRIER setting + if (1 && AMDGPU) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_fast_barrier = 0; + double best_cost = -1.0; + for (u32 fast_barrier : {0, 1}) { + shared.args->flags["FAST_BARRIER"] = to_string(fast_barrier); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using FAST_BARRIER=%u is %6.1f\n", fft.spec().c_str(), fast_barrier, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_fast_barrier = fast_barrier; } + } + log("Best FAST_BARRIER is %u. Default FAST_BARRIER is 0.\n", best_fast_barrier); + shared.args->flags["FAST_BARRIER"] = to_string(best_fast_barrier); + } + + // Find best TAIL_TRIGS setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_trigs = 0; + double best_cost = -1.0; + for (u32 tail_trigs : {0, 1, 2}) { + shared.args->flags["TAIL_TRIGS"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TAIL_TRIGS=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } + } + log("Best TAIL_TRIGS is %u. Default TAIL_TRIGS is 2.\n", best_tail_trigs); + shared.args->flags["TAIL_TRIGS"] = to_string(best_tail_trigs); + } + + // Find best TAIL_KERNELS setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_kernels = 0; + double best_cost = -1.0; + for (u32 tail_kernels : {0, 1, 2, 3}) { + shared.args->flags["TAIL_KERNELS"] = to_string(tail_kernels); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } + } + if (best_tail_kernels & 1) + log("Best TAIL_KERNELS is %u. Default TAIL_KERNELS is 2.\n", best_tail_kernels); + else + log("Best TAIL_KERNELS is %u (but best may be %u when running two workers on one GPU). Default TAIL_KERNELS is 2.\n", best_tail_kernels, best_tail_kernels | 1); + shared.args->flags["TAIL_KERNELS"] = to_string(best_tail_kernels); + } -//GW: detail all the configs we should auto-time first + // Find best TABMUL_CHAIN setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, 101, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tabmul_chain = 0; + double best_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + shared.args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TABMUL_CHAIN=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } + } + log("Best TABMUL_CHAIN is %u. Default TABMUL_CHAIN is 0.\n", best_tabmul_chain); + shared.args->flags["TABMUL_CHAIN"] = to_string(best_tabmul_chain); + } + + // Find best PAD setting. Default is 256 bytes for AMD, 0 for all others. + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_pad = 0; + double best_cost = -1.0; + for (u32 pad : {0, 64, 128, 256, 512}) { + shared.args->flags["PAD"] = to_string(pad); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } + } + log("Best PAD is %u bytes. Default PAD is %u bytes.\n", best_pad, AMDGPU ? 256 : 0); + shared.args->flags["PAD"] = to_string(best_pad); + } + + // Find best NONTEMPORAL setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_nontemporal = 0; + double best_cost = -1.0; + for (u32 nontemporal : {0, 1}) { + shared.args->flags["NONTEMPORAL"] = to_string(nontemporal); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } + } + log("Best NONTEMPORAL is %u. Default NONTEMPORAL is 0.\n", best_nontemporal); + shared.args->flags["NONTEMPORAL"] = to_string(best_nontemporal); + } + + // Find best UNROLL_W setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_unroll_w = 0; + double best_cost = -1.0; + for (u32 unroll_w : {0, 1}) { + shared.args->flags["UNROLL_W"] = to_string(unroll_w); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using UNROLL_W=%u is %6.1f\n", fft.spec().c_str(), unroll_w, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_w = unroll_w; } + } + log("Best UNROLL_W is %u. Default UNROLL_W is %u.\n", best_unroll_w, AMDGPU ? 0 : 1); + shared.args->flags["UNROLL_W"] = to_string(best_unroll_w); + } + + // Find best UNROLL_H setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_unroll_h = 0; + double best_cost = -1.0; + for (u32 unroll_h : {0, 1}) { + shared.args->flags["UNROLL_H"] = to_string(unroll_h); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using UNROLL_H=%u is %6.1f\n", fft.spec().c_str(), unroll_h, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_h = unroll_h; } + } + log("Best UNROLL_H is %u. Default UNROLL_H is %u.\n", best_unroll_h, AMDGPU && shape.height >= 1024 ? 0 : 1); + shared.args->flags["UNROLL_H"] = to_string(best_unroll_h); + } + + // Find best ZEROHACK_W setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_zerohack_w = 0; + double best_cost = -1.0; + for (u32 zerohack_w : {0, 1}) { + shared.args->flags["ZEROHACK_W"] = to_string(zerohack_w); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using ZEROHACK_W=%u is %6.1f\n", fft.spec().c_str(), zerohack_w, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_w = zerohack_w; } + } + log("Best ZEROHACK_W is %u. Default ZEROHACK_W is 1.\n", best_zerohack_w); + shared.args->flags["ZEROHACK_W"] = to_string(best_zerohack_w); + } + + // Find best ZEROHACK_H setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_zerohack_h = 0; + double best_cost = -1.0; + for (u32 zerohack_h : {0, 1}) { + shared.args->flags["ZEROHACK_H"] = to_string(zerohack_h); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using ZEROHACK_H=%u is %6.1f\n", fft.spec().c_str(), zerohack_h, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_h = zerohack_h; } + } + log("Best ZEROHACK_H is %u. Default ZEROHACK_H is 1.\n", best_zerohack_h); + shared.args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); + } + + // Find best MIDDLE_IN_LDS_TRANSPOSE setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_in_lds_transpose = 0; + double best_cost = -1.0; + for (u32 middle_in_lds_transpose : {0, 1}) { + shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } + } + log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); + shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); + } + + // Find best MIDDLE_OUT_LDS_TRANSPOSE setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_out_lds_transpose = 0; + double best_cost = -1.0; + for (u32 middle_out_lds_transpose : {0, 1}) { + shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } + } + log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); + shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); + } + + // Find best BIGLIT setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_biglit = 0; + double best_cost = -1.0; + for (u32 biglit : {0, 1}) { + shared.args->flags["BIGLIT"] = to_string(biglit); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using BIGLIT=%u is %6.1f\n", fft.spec().c_str(), biglit, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_biglit = biglit; } + } + log("Best BIGLIT is %u. Default BIGLIT is 1. The BIGLIT=0 option will probably be deprecated.\n", best_biglit); + shared.args->flags["BIGLIT"] = to_string(best_biglit); + } + + //GW: Time some IN/OUT_WG/SIZEX combos? + } // Flags that prune the amount of shapes and variants to time. // These should be computed automatically and saved in the tune.txt or config.txt file. @@ -341,11 +575,8 @@ void Tune::tune() { // The width = height = 512 FFT shape is so good, we probably don't need to time the width = 1024, height = 256 shape. bool skip_1K_256 = 1; - // There are some variands only AMD GPUs can execute - bool AMDGPU = isAmdGpu(q->context->deviceId()); - -// make command line args for this? -skip_some_WH_variants = 2; +// make command line args for this? +skip_some_WH_variants = 2; // should default be 1?? skip_1K_256 = 0; //GW: Suggest tuning with TAIL_KERNELS=2 even if production runs use TAIL_KERNELS=3 @@ -359,7 +590,6 @@ skip_1K_256 = 0; map fastest_height_variants; vector results = TuneEntry::readTuneFile(*args); - vector shapes = FFTShape::multiSpec(args->fftSpec); // Loop through all possible FFT shapes for (const FFTShape& shape : shapes) { From 4de83d0d56cee9c3e49039e9035ae6bd37e61a83 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 14 Jun 2025 00:36:48 +0000 Subject: [PATCH 006/115] Added tune code for IN_WG,IN_SIZEX,OUT_WG,OUT_SIZEX --- src/tune.cpp | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/tune.cpp b/src/tune.cpp index b4fe9df1..ff1b76d6 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -338,6 +338,50 @@ void Tune::tune() { u32 variant = 101; //GW: if fft spec on the command line specifies a variant then we should use that variant (I get some interesting results with 000 vs 101 vs 201 vs 202 likely due to rocm optimizer) + // Find best IN_WG,IN_SIZEX setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_in_wg = 0; + u32 best_in_sizex = 0; + double best_cost = -1.0; + for (u32 in_wg : {64, 128, 256}) { + for (u32 in_sizex : {8, 16, 32}) { + shared.args->flags["IN_WG"] = to_string(in_wg); + shared.args->flags["IN_SIZEX"] = to_string(in_sizex); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using IN_WG=%u, IN_SIZEX=%u is %6.1f\n", fft.spec().c_str(), in_wg, in_sizex, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } + } + } + log("Best IN_WG, IN_SIZEX is %u, %u. Default is 128, 16.\n", best_in_wg, best_in_sizex); + shared.args->flags["IN_WG"] = to_string(best_in_wg); + shared.args->flags["IN_SIZEX"] = to_string(best_in_sizex); + } + + // Find best OUT_WG,OUT_SIZEX setting + if (1) { + const FFTShape& shape = shapes[0]; + FFTConfig fft{shape, variant, CARRY_32}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_out_wg = 0; + u32 best_out_sizex = 0; + double best_cost = -1.0; + for (u32 out_wg : {64, 128, 256}) { + for (u32 out_sizex : {8, 16, 32}) { + shared.args->flags["OUT_WG"] = to_string(out_wg); + shared.args->flags["OUT_SIZEX"] = to_string(out_sizex); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using OUT_WG=%u, OUT_SIZEX=%u is %6.1f\n", fft.spec().c_str(), out_wg, out_sizex, cost); + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } + } + } + log("Best OUT_WG, OUT_SIZEX is %u, %u. Default is 128, 16.\n", best_out_wg, best_out_sizex); + shared.args->flags["OUT_WG"] = to_string(best_out_wg); + shared.args->flags["OUT_SIZEX"] = to_string(best_out_sizex); + } + // Find best FAST_BARRIER setting if (1 && AMDGPU) { const FFTShape& shape = shapes[0]; From 4826949db6848f52d7d60535b37a4c0042ed6f60 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 23 Jun 2025 19:31:00 +0000 Subject: [PATCH 007/115] Solved the nVidia finish causing a CPU busy wait. The queue contains at most 500 kernels, a marker, and 500 more kernels. When prpll tries to add another kernel to the queue, we loop checking if the marker has been reached else performing a lengthy sleep (knowing there are 500 kernels to execute after the marker). --- src/Args.cpp | 3 --- src/Args.h | 1 - src/Gpu.cpp | 18 ++++++----------- src/Queue.cpp | 53 +++++++++++++++++++++++++++++++++++++++++++++++++-- src/Queue.h | 11 +++++++++++ src/tinycl.h | 3 ++- 6 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 57b66be2..2cc1e072 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -333,9 +333,6 @@ void Args::parse(const string& line) { if (workers < 1 || workers > 4) { throw "Number of workers must be between 1 and 4"; } - } else if (key == "-flush") { - flushStep = stoi(s); - assert(flushStep); } else if (key == "-cache") { useCache = true; } else if (key == "-noclean") { diff --git a/src/Args.h b/src/Args.h index 291b8d5b..22ff5844 100644 --- a/src/Args.h +++ b/src/Args.h @@ -75,7 +75,6 @@ class Args { u32 workers = 1; u32 blockSize = 1000; u32 logStep = 20000; - u32 flushStep = 400; string fftSpec; u32 prpExp = 0; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index d72dde7f..6487dd54 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1504,15 +1504,13 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); bool doLog = k % logStep == 0; - if (!leadOut || (!doCheck && !doLog)) { - if (k % args.flushStep == 0) { queue->finish(); } - continue; - } + if (!leadOut || (!doCheck && !doLog)) continue; assert(doCheck || doLog); u64 res = dataResidue(); float secsPerIt = iterationTimer.reset(k); + queue->setSquareTime((int) (secsPerIt * 1'000'000)); vector rawCheck = readChecked(bufCheck); if (rawCheck.empty()) { @@ -1630,10 +1628,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { squareLL(bufData, leadIn, leadOut); leadIn = leadOut; - if (!doLog) { - if (k % args.flushStep == 0) { queue->finish(); } // Periodically flush the queue - continue; - } + if (!doLog) continue; u64 res64 = 0; auto data = readData(); @@ -1656,6 +1651,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { } float secsPerIt = iterationTimer.reset(k); + queue->setSquareTime((int) (secsPerIt * 1'000'000)); log("%9u %016" PRIx64 " %4.0f\n", k, res64, secsPerIt * 1'000'000); if (k >= kEnd) { return {isAllZero, res64}; } @@ -1707,16 +1703,14 @@ array Gpu::isCERT(const Task& task) { squareCERT(bufData, leadIn, leadOut); leadIn = leadOut; - if (!doLog) { - if (k % args.flushStep == 0) { queue->finish(); } // Periodically flush the queue - continue; - } + if (!doLog) continue; Words data = readData(); assert(data.size() >= 2); u64 res64 = (u64(data[1]) << 32) | data[0]; float secsPerIt = iterationTimer.reset(k); + queue->setSquareTime((int) (secsPerIt * 1'000'000)); log("%9u %016" PRIx64 " %4.0f\n", k, res64, secsPerIt * 1'000'000); if (k >= kEnd) { diff --git a/src/Queue.cpp b/src/Queue.cpp index 79be78d9..05b0ae68 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -18,8 +18,16 @@ void Events::synced() { Queue::Queue(const Context& context, bool profile) : QueueHolder{makeQueue(context.deviceId(), context.get(), profile)}, hasEvents{profile}, - context{&context} -{} + context{&context}, + markerEvent{}, + markerQueued(false), + queueCount(0), + squareTime(50) +{ + // Formerly a constant (thus the CAPS). nVidia is 3% CPU load at 400 or 500, and 35% load at 800 on my Linux machine. + // AMD is just over 2% load at 1600 and 3200 on the same Linux machine. Marginally better timings(?) at 3200. + MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings +} void Queue::writeTE(cl_mem buf, u64 size, const void* data, TimeInfo* tInfo) { add(::write(get(), {}, true, buf, size, data, hasEvents), tInfo); @@ -43,9 +51,12 @@ void Queue::print() { void Queue::add(EventHolder&& e, TimeInfo* ti) { if (hasEvents) { events.emplace_back(std::move(e), ti); } + queueCount++; + if (queueCount == MAX_QUEUE_COUNT) queueMarkerEvent(); } void Queue::readSync(cl_mem buf, u32 size, void* out, TimeInfo* tInfo) { + queueMarkerEvent(); add(read(get(), {}, true, buf, size, out, hasEvents), tInfo); events.synced(); } @@ -63,6 +74,44 @@ void Queue::run(cl_kernel kernel, size_t groupSize, size_t workSize, TimeInfo* t } void Queue::finish() { + waitForMarkerEvent(); ::finish(get()); events.synced(); + queueCount = 0; +} + +void Queue::queueMarkerEvent() { + waitForMarkerEvent(); + if (queueCount) { + // AMD GPUs have no trouble waiting for a finish without a CPU busy wait. So, instead of markers and events, simply run finish every now and then. + if (isAmdGpu(context->deviceId())) { + finish(); + } + // Enqueue a marker for nVidia GPUs + else { + clEnqueueMarkerWithWaitList(get(), 0, NULL, &markerEvent); + markerQueued = true; + queueCount = 0; + } + } +} + +void Queue::waitForMarkerEvent() { + if (!markerQueued) return; + // By default, nVidia finish causes a CPU busy wait. Instead, sleep for a while. Since we know how many items are enqueued after the marker we can make an + // educated guess of how long to sleep to keep CPU overhead low. + while (getEventInfo(markerEvent) != CL_COMPLETE) { +#if defined(__CYGWIN__) + sleep(1); // 1 second. A very steep overhead as 500 iterations won't take that long. +#else + usleep(1 + queueCount * squareTime / 10); // There are 4 kernels per squaring. Don't overestimate sleep time. Divide by 10 instead of 4. +#endif + } + markerQueued = false; +} + +void Queue::setSquareTime(int time) { + if (time < 30) time = 30; // Assume a minimum square time of 30us + if (time > 3000) time = 3000; // Assume a maximum square time of 3000us + squareTime = time; } diff --git a/src/Queue.h b/src/Queue.h index 917d45ee..06efa159 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -49,4 +49,15 @@ class Queue : public QueueHolder { void readAsync(cl_mem buf, u32 size, void* out, TimeInfo* tInfo); void copyBuf(cl_mem src, cl_mem dst, u32 size, TimeInfo* tInfo); void finish(); + + void setSquareTime(int); // Set the time to do one squaring (in microseconds) + +private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. + int MAX_QUEUE_COUNT; // Queue size before a marker will be enqueued. Typically, 100 to 1000 squarings. + cl_event markerEvent; // Event associated with an enqueued marker placed in the queue every MAX_QUEUE_COUNT entries and before r/w operations. + bool markerQueued; // TRUE if a marker and event have been queued + int queueCount; // Count of items added to the queue since last marker + int squareTime; // Time to do one squaring (in microseconds) + void queueMarkerEvent(); // Queue the marker event + void waitForMarkerEvent(); // Wait for marker event to complete }; diff --git a/src/tinycl.h b/src/tinycl.h index 276d1e3a..e3317657 100644 --- a/src/tinycl.h +++ b/src/tinycl.h @@ -71,7 +71,8 @@ cl_command_queue clCreateCommandQueueWithPropertiesAPPLE(cl_context, cl_device_i #else cl_command_queue clCreateCommandQueueWithProperties(cl_context, cl_device_id, const cl_queue_properties *, int *); #endif - + +int clEnqueueMarkerWithWaitList(cl_command_queue, unsigned num_events_in_wait_list, const cl_event* event_wait_list, cl_event* event); int clEnqueueReadBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, unsigned numEvents, const cl_event *waitEvents, cl_event *outEvent); int clEnqueueWriteBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, From 757c3f73d802ca4ae3836f825060f0bbe6acd9cd Mon Sep 17 00:00:00 2001 From: george Date: Mon, 23 Jun 2025 20:45:04 +0000 Subject: [PATCH 008/115] At Mihai's suggestion, use a portable method to sleep N microseconds. --- src/Queue.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Queue.cpp b/src/Queue.cpp index 05b0ae68..d2c1ceac 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -7,6 +7,8 @@ #include "log.h" #include +#include +#include void Events::clearCompleted() { while (!empty() && front().isComplete()) { pop_front(); } } @@ -101,11 +103,8 @@ void Queue::waitForMarkerEvent() { // By default, nVidia finish causes a CPU busy wait. Instead, sleep for a while. Since we know how many items are enqueued after the marker we can make an // educated guess of how long to sleep to keep CPU overhead low. while (getEventInfo(markerEvent) != CL_COMPLETE) { -#if defined(__CYGWIN__) - sleep(1); // 1 second. A very steep overhead as 500 iterations won't take that long. -#else - usleep(1 + queueCount * squareTime / 10); // There are 4 kernels per squaring. Don't overestimate sleep time. Divide by 10 instead of 4. -#endif + // There are 4 kernels per squaring. Don't overestimate sleep time. Divide by 10 instead of 4. + std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / 10)); } markerQueued = false; } From d198fd5b3f09bc98c22d999f80dd60c05b80998c Mon Sep 17 00:00:00 2001 From: george Date: Mon, 30 Jun 2025 00:53:13 +0000 Subject: [PATCH 009/115] Fixed memory allocation bug with MIDDLE=4, PAD=512 --- src/Gpu.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 6487dd54..3b0f980b 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -492,7 +492,8 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& BUF(bufStatsCarry, CARRY_SIZE), // Allocate extra for padding. We can probably tighten up the amount of extra memory allocated. - #define total_padding ((pad_size == 0 ? 0 : (pad_size <= 128 ? N/8 : (pad_size <= 256 ? N/4 : N/2)))) + // The worst case seems to be MIDDLE=4, PAD_SIZE=512 + #define total_padding (((pad_size == 0 ? 0 : (pad_size <= 128 ? N/8 : (pad_size <= 256 ? N/4 : N/2)))) * (fft.shape.middle == 4 ? 5 : 4) / 4) BUF(buf1, N + total_padding), BUF(buf2, N + total_padding), BUF(buf3, N + total_padding), @@ -1365,6 +1366,7 @@ double Gpu::timePRP() { if (Signal::stopRequested()) { throw "stop requested"; } Timer t; + queue->setSquareTime(0); // Busy wait on nVidia to get the most accurate timings while tuning bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { From fd59830709d2940b696fa2c3db1b9f1c78c15d2e Mon Sep 17 00:00:00 2001 From: george Date: Sat, 5 Jul 2025 18:41:53 +0000 Subject: [PATCH 010/115] Changed results.txt to results-N.txt -- the same naming scheme as worktodo files. Autoprimenet needs this change. --- src/Args.cpp | 8 ++------ src/Args.h | 1 - src/Task.cpp | 24 +++++++++++++----------- src/Task.h | 6 +++--- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 2cc1e072..17aa6738 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -160,7 +160,6 @@ named "config.txt" in the prpll run directory. A lower power reduces disk space requirements but increases the verification cost. A higher power increases disk usage a lot. e.g. proof power 10 for a 120M exponent uses about %.0fGB of disk space. --results : name of results file, default '%s' -iters : run next PRP test for iterations and exit. Multiple of 10000. -save : specify the number of savefiles to keep (default %u). -noclean : do not delete data after the test is complete. @@ -218,7 +217,7 @@ named "config.txt" in the prpll run directory. Device selection : use one of -uid , -pci , -device , see the list below -)", ProofSet::diskUsageGB(120000000, 10), resultsFile.string().c_str(), nSavefiles); +)", ProofSet::diskUsageGB(120000000, 10), nSavefiles); vector deviceIds = getAllDeviceIDs(); if (!deviceIds.empty()) { @@ -365,7 +364,6 @@ void Args::parse(const string& line) { throw("-pool requires an absolute path"); } } - else if (key == "-results") { resultsFile = s; } else if (key == "-maxAlloc" || key == "-maxalloc") { assert(!s.empty()); u32 multiple = (s.back() == 'G') ? (1u << 30) : (1u << 20); @@ -425,12 +423,10 @@ void Args::setDefaults() { if (!masterDir.empty()) { assert(masterDir.is_absolute()); - for (filesystem::path* p : {&proofResultDir, &proofToVerifyDir, &cacheDir, &resultsFile}) { + for (filesystem::path* p : {&proofResultDir, &proofToVerifyDir, &cacheDir}) { if (p->is_relative()) { *p = masterDir / *p; } } } for (auto& p : {proofResultDir, proofToVerifyDir, cacheDir}) { fs::create_directory(p); } - - File::openAppend(resultsFile); // verify that it's possible to write results } diff --git a/src/Args.h b/src/Args.h index 22ff5844..c7e92259 100644 --- a/src/Args.h +++ b/src/Args.h @@ -66,7 +66,6 @@ class Args { fs::path proofResultDir = "proof"; fs::path proofToVerifyDir = "proof-tmp"; fs::path cacheDir = "kernel-cache"; - fs::path resultsFile = "results.txt"; // fs::path tuneFile = "tune.txt"; bool keepProof = false; diff --git a/src/Task.cpp b/src/Task.cpp index 6056c7e0..27f824c4 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -132,19 +132,20 @@ vector tailFields(const std::string &AID, const Args &args) { }; } -void writeResult(u32 E, const char *workType, const string &status, const std::string &AID, const Args &args, +void writeResult(u32 instance, u32 E, const char *workType, const string &status, const std::string &AID, const Args &args, const vector& extras) { + fs::path resultsFile = "results-" + to_string(instance) + ".txt"; vector fields = commonFields(E, workType, status); fields += extras; fields += tailFields(AID, args); string s = json(std::move(fields)); log("%s\n", s.c_str()); - File::append(args.resultsFile, s + '\n'); + File::append(resultsFile, s + '\n'); } } -void Task::writeResultPRP(const Args &args, bool isPrime, u64 res64, const string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const { +void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res64, const string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const { vector fields{json("res64", hex(res64)), json("res2048", res2048), json("residue-type", 1), @@ -165,20 +166,20 @@ void Task::writeResultPRP(const Args &args, bool isPrime, u64 res64, const strin } } - writeResult(exponent, "PRP-3", isPrime ? "P" : "C", AID, args, fields); + writeResult(instance, exponent, "PRP-3", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultLL(const Args &args, bool isPrime, u64 res64, u32 fftSize) const { +void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64, u32 fftSize) const { vector fields{json("res64", hex(res64)), json("fft-length", fftSize), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this }; - writeResult(exponent, "LL", isPrime ? "P" : "C", AID, args, fields); + writeResult(instance, exponent, "LL", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultCERT(const Args &args, array hash, u32 squarings, u32 fftSize) const { +void Task::writeResultCERT(const Args &args, u32 instance, array hash, u32 squarings, u32 fftSize) const { string hexhash = hex(hash[3]) + hex(hash[2]) + hex(hash[1]) + hex(hash[0]); vector fields{json("worktype", "Cert"), json("exponent", exponent), @@ -191,7 +192,8 @@ void Task::writeResultCERT(const Args &args, array hash, u32 squarings, fields += tailFields(AID, args); string s = json(std::move(fields)); log("%s\n", s.c_str()); - File::append(args.resultsFile, s + '\n'); + fs::path resultsFile = "results-" + to_string(instance) + ".txt"; + File::append(resultsFile, s + '\n'); } void Task::execute(GpuCommon shared, Queue *q, u32 instance) { @@ -216,11 +218,11 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { if (kind == PRP) { auto [tmpIsPrime, res64, nErrors, proofPath, res2048] = gpu->isPrimePRP(*this); isPrime = tmpIsPrime; - writeResultPRP(*shared.args, isPrime, res64, res2048, fft.size(), nErrors, proofPath); + writeResultPRP(*shared.args, instance, isPrime, res64, res2048, fft.size(), nErrors, proofPath); } else { // LL auto [tmpIsPrime, res64] = gpu->isPrimeLL(*this); isPrime = tmpIsPrime; - writeResultLL(*shared.args, isPrime, res64, fft.size()); + writeResultLL(*shared.args, instance, isPrime, res64, fft.size()); } Worktodo::deleteTask(*this, instance); @@ -232,7 +234,7 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { } } else if (kind == CERT) { auto sha256 = gpu->isCERT(*this); - writeResultCERT(*shared.args, sha256, squarings, fft.size()); + writeResultCERT(*shared.args, instance, sha256, squarings, fft.size()); Worktodo::deleteTask(*this, instance); } else { throw "Unexpected task kind " + to_string(kind); diff --git a/src/Task.h b/src/Task.h index 914e364c..6d870263 100644 --- a/src/Task.h +++ b/src/Task.h @@ -27,7 +27,7 @@ class Task { string verifyPath; // For Verify void execute(GpuCommon shared, Queue* q, u32 instance); - void writeResultPRP(const Args&, bool isPrime, u64 res64, const std::string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const; - void writeResultLL(const Args&, bool isPrime, u64 res64, u32 fftSize) const; - void writeResultCERT(const Args&, array hash, u32 squarings, u32 fftSize) const; + void writeResultPRP(const Args&, u32 instance, bool isPrime, u64 res64, const std::string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const; + void writeResultLL(const Args&, u32 instance, bool isPrime, u64 res64, u32 fftSize) const; + void writeResultCERT(const Args&, u32 instance, array hash, u32 squarings, u32 fftSize) const; }; From 0e25bbf0cca0577eed133218727f6e54ab7bd223 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 9 Jul 2025 01:18:06 +0000 Subject: [PATCH 011/115] Fixed bug where an AMD-only FFT variant (000) was chosen whenever no useful FFT spec could be found in tune.txt --- src/FFTConfig.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index ae34cfa9..d4dd6d69 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -251,7 +251,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { // Take the first FFT that can handle E for (const FFTShape& shape : FFTShape::allShapes()) { - for (u32 v = 0; v < 4; ++v) { + for (u32 v : {101, 202}) { if (FFTConfig fft{shape, v, CARRY_AUTO}; fft.maxExp() * args.fftOverdrive >= E) { return fft; } } } From 7afe0a572ccdd8858e02d27b157c1047db729d0e Mon Sep 17 00:00:00 2001 From: george Date: Fri, 11 Jul 2025 23:22:57 +0000 Subject: [PATCH 012/115] Don't output tedious "missing BPW" message when that is expected --- src/FFTConfig.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index d4dd6d69..669f7eea 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -156,9 +156,11 @@ FFTShape::FFTShape(u32 w, u32 m, u32 h) : while (h < 256) { h *= 2; m /= 2; } bpw = FFTShape{w, m, h}.bpw; for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) bpw[j] -= 0.05; // Assume this fft spec is worse than measured fft specs - printf("BPW info for %s not found, defaults={", s.c_str()); - for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", bpw[j]); - printf("}\n"); + if (this->isFavoredShape()) { // Don't output this warning message for non-favored shapes (we expect the BPW info to be missing) + printf("BPW info for %s not found, defaults={", s.c_str()); + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", bpw[j]); + printf("}\n"); + } } } } From 09fea7988241e85cf9a60cb9649e8400fd3771fc Mon Sep 17 00:00:00 2001 From: george Date: Fri, 11 Jul 2025 23:24:26 +0000 Subject: [PATCH 013/115] Fixed compiler warning for Windows compiler (long vs. long long difference) --- src/clwrap.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/clwrap.cpp b/src/clwrap.cpp index 2bbedc7b..29cdab5d 100644 --- a/src/clwrap.cpp +++ b/src/clwrap.cpp @@ -231,7 +231,7 @@ string getBuildLog(cl_program program, cl_device_id deviceId) { if (logSize > 0) { // Avoid printing excessively large compile logs if (logSize > maxLogSize) { - log("getBuildLog: log size is %lu bytes, not showing\n", logSize); + log("getBuildLog: log size is %lu bytes, not showing\n", (unsigned long) logSize); return {}; } std::unique_ptr buf(new char[logSize + 1]); From 60c991d36694d08250b126fd7f9b2f4e920e7ff9 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 5 Aug 2025 01:02:19 +0000 Subject: [PATCH 014/115] Corrected comments on UNROLL_W. Simplified biglit1 comment. --- src/cl/carryfused.cl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index bf905597..46bba3c3 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -19,7 +19,7 @@ void spin() { // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, - CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { #if 0 // fft_WIDTH uses shufl_int instead of shufl local T2 lds[WIDTH / 4]; @@ -51,11 +51,6 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs // which causes a terrible reduction in occupancy. // fft_WIDTH(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); - -// A temporary hack until we figure out which combinations we want to finally offer: -// UNROLL_W=0: old fft_WIDTH, no loop unrolling -// UNROLL_W=1: old fft_WIDTH, loop unrolling -// UNROLL_W=3: new fft_WIDTH if applicable. Slightly better on Radeon VII -- more study needed as to why results weren't better. #if ZEROHACK_W new_fft_WIDTH1(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); #else @@ -105,10 +100,8 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate big-word/little-word flags #if BIGLIT -// bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; -// bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; - bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; + bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; // Same as frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; #else bool biglit0 = test(b, 2 * i); bool biglit1 = test(b, 2 * i + 1); From 957f0cb986f8b80de54da1631645a096b2c32653 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 5 Aug 2025 01:03:15 +0000 Subject: [PATCH 015/115] Fixed bug in -ctune where it used a BCAST FFT variant which is not supported by nVidia GPUs. --- src/tune.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index ff1b76d6..e14fbe46 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -290,7 +290,7 @@ void Tune::ctune() { } for (FFTShape shape : shapes) { - FFTConfig fft{shape, 0, CARRY_32}; + FFTConfig fft{shape, 101, CARRY_32}; u32 exponent = primes.prevPrime(fft.maxExp()); // log("tuning %10s with exponent %u\n", fft.shape.spec().c_str(), exponent); From 05d0d7309512244900ee21aa1acc5fcdc9c62068 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 12 Aug 2025 17:43:11 +0000 Subject: [PATCH 016/115] Initialize all varibles before initializing the thread. Prevents a race condition were launched thread accesses uninitialized variables. --- src/Background.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Background.h b/src/Background.h index 08f4d079..96cbe804 100644 --- a/src/Background.h +++ b/src/Background.h @@ -16,10 +16,10 @@ class Background { unsigned maxSize; std::deque > tasks; - std::jthread thread; std::mutex mut; std::condition_variable cond; - bool stopRequested{}; + bool stopRequested; + std::jthread thread; void run() { std::function task; @@ -59,6 +59,7 @@ class Background { public: Background(unsigned size = 2) : maxSize{size}, + stopRequested(false), thread{&Background::run, this} { } From 10938a15c61d74cd2dd7d6a75086efbc043fe52b Mon Sep 17 00:00:00 2001 From: george Date: Sun, 14 Sep 2025 01:18:35 +0000 Subject: [PATCH 017/115] Massive changes to support NTTs and hybrid FFTs. First cut, much cleanup work remains. --- Makefile | 5 +- src/Buffer.h | 2 +- src/Gpu.cpp | 590 ++++++++++---- src/Gpu.h | 207 +++-- src/Proof.cpp | 2 +- src/Queue.cpp | 1 - src/Task.cpp | 29 +- src/TrigBufCache.cpp | 664 +++++++++++++++- src/TrigBufCache.h | 60 +- src/cl/base.cl | 114 ++- src/cl/carry.cl | 550 ++++++++++++- src/cl/carryb.cl | 21 +- src/cl/carryfused.cl | 1656 +++++++++++++++++++++++++++++++++++++++- src/cl/carryinc.cl | 239 +++++- src/cl/carryutil.cl | 556 ++++++++++++-- src/cl/etc.cl | 5 +- src/cl/fft-middle.cl | 568 +++++++++++++- src/cl/fft10.cl | 4 + src/cl/fft11.cl | 4 + src/cl/fft12.cl | 4 + src/cl/fft13.cl | 4 + src/cl/fft14.cl | 4 + src/cl/fft15.cl | 4 + src/cl/fft16.cl | 133 +++- src/cl/fft3.cl | 4 + src/cl/fft4.cl | 224 +++++- src/cl/fft5.cl | 4 + src/cl/fft6.cl | 4 + src/cl/fft7.cl | 4 + src/cl/fft8.cl | 164 +++- src/cl/fft9.cl | 4 + src/cl/fftbase.cl | 392 +++++++++- src/cl/fftheight.cl | 182 ++++- src/cl/ffthin.cl | 98 ++- src/cl/fftmiddlein.cl | 173 ++++- src/cl/fftmiddleout.cl | 187 +++++ src/cl/fftp.cl | 451 ++++++++++- src/cl/fftw.cl | 84 +- src/cl/fftwidth.cl | 134 +++- src/cl/math.cl | 742 +++++++++++++++++- src/cl/middle.cl | 326 +++++++- src/cl/selftest.cl | 47 +- src/cl/tailmul.cl | 431 ++++++++++- src/cl/tailsquare.cl | 941 ++++++++++++++++++++++- src/cl/tailutil.cl | 526 ++++++++++++- src/cl/transpose.cl | 51 +- src/cl/trig.cl | 89 ++- src/cl/weight.cl | 64 ++ src/common.h | 26 + src/main.cpp | 21 +- src/state.cpp | 64 +- src/state.h | 4 +- 52 files changed, 10186 insertions(+), 681 deletions(-) diff --git a/Makefile b/Makefile index f62c5266..e9c62f34 100644 --- a/Makefile +++ b/Makefile @@ -19,8 +19,9 @@ else CXX = g++ endif -COMMON_FLAGS = -Wall -std=c++20 -# -static-libstdc++ -static-libgcc +COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc +# For mingw-64 use this: +#COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc -static # -fext-numeric-literals ifeq ($(HOST_OS), Darwin) diff --git a/src/Buffer.h b/src/Buffer.h index c172c883..7866975d 100644 --- a/src/Buffer.h +++ b/src/Buffer.h @@ -26,7 +26,7 @@ class Buffer { Queue* queue; TimeInfo *tInfo; - + Buffer(cl_context context, TimeInfo *tInfo, Queue* queue, size_t size, unsigned flags, const T* ptr = nullptr) : ptr{size == 0 ? NULL : makeBuf_(context, flags, size * sizeof(T), ptr)} , size{size} diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 3b0f980b..a7c0a7e1 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -35,8 +35,12 @@ #define M_PI 3.141592653589793238462643383279502884 #endif +#define CARRY_LEN 8 + namespace { +#if FFT_FP64 + u32 kAt(u32 H, u32 line, u32 col) { return (line + col * H) * 2; } double weight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { @@ -57,8 +61,6 @@ double invWeightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { double boundUnderOne(double x) { return std::min(x, nexttoward(1, 0)); } -#define CARRY_LEN 8 - Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { u32 N = 2u * W * H; @@ -104,36 +106,93 @@ Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { } assert(bits.size() == N / 32); - vector bitsC; + return Weights{weightsConstIF, weightsIF, bits}; +} + +#endif + +#if FFT_FP32 + +u32 kAt(u32 H, u32 line, u32 col) { return (line + col * H) * 2; } + +float weight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N); +} + +float invWeight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2(-(double)(extra(N, E, kAt(H, line, col) + rep)) / N); +} + +float weightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; +} + +float invWeightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2(- (double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; +} + +float boundUnderOne(float x) { return std::min(x, nexttowardf(1, 0)); } + +Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { + u32 N = 2u * W * H; + + u32 groupWidth = W / nW; + + // Inverse + Forward + vector weightsConstIF; + vector weightsIF; + for (u32 thread = 0; thread < groupWidth; ++thread) { + auto iw = invWeight(N, E, H, 0, thread, 0); + auto w = weight(N, E, H, 0, thread, 0); + // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer + // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF + // buffer, so there is an unfortunate duplication of these weights. + if (!AmdGpu) { + weightsConstIF.push_back(2 * boundUnderOne(iw)); + weightsConstIF.push_back(2 * w); + } + weightsIF.push_back(2 * boundUnderOne(iw)); + weightsIF.push_back(2 * w); + } + + // the group order matches CarryA/M (not fftP/CarryFused). + for (u32 gy = 0; gy < H; ++gy) { + weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); + weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); + } - for (u32 gy = 0; gy < H / CARRY_LEN; ++gy) { - for (u32 gx = 0; gx < nW; ++gx) { - for (u32 thread = 0; thread < groupWidth; ) { - std::bitset<32> b; - for (u32 bitoffset = 0; bitoffset < 32; bitoffset += CARRY_LEN * 2, ++thread) { - for (u32 block = 0; block < CARRY_LEN; ++block) { - for (u32 rep = 0; rep < 2; ++rep) { - if (isBigWord(N, E, kAt(H, gy * CARRY_LEN + block, gx * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } - } + vector bits; + + for (u32 line = 0; line < H; ++line) { + for (u32 thread = 0; thread < groupWidth; ) { + std::bitset<32> b; + for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { + for (u32 block = 0; block < nW; ++block) { + for (u32 rep = 0; rep < 2; ++rep) { + if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } + } } - bitsC.push_back(b.to_ulong()); } + bits.push_back(b.to_ulong()); } } - assert(bitsC.size() == N / 32); + assert(bits.size() == N / 32); - return Weights{weightsConstIF, weightsIF, bits, bitsC}; + return Weights{weightsConstIF, weightsIF, bits}; } -string toLiteral(u32 value) { return to_string(value) + 'u'; } +#endif + string toLiteral(i32 value) { return to_string(value); } +string toLiteral(u32 value) { return to_string(value) + 'u'; } +[[maybe_unused]] string toLiteral(i64 value) { return to_string(value) + "l"; } [[maybe_unused]] string toLiteral(u64 value) { return to_string(value) + "ul"; } template string toLiteral(F value) { std::ostringstream ss; ss << std::setprecision(numeric_limits::max_digits10) << value; + if (sizeof(F) == 4) ss << "f"; string s = std::move(ss).str(); // verify exact roundtrip @@ -167,7 +226,12 @@ string toLiteral(const std::array& v) { string toLiteral(const string& s) { return s; } -string toLiteral(double2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(float2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(double2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(int2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(long2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(uint2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(ulong2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } template string toDefine(const string& k, T v) { return " -D"s + k + '=' + toLiteral(v); } @@ -266,16 +330,66 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< } u32 N = fft.shape.size(); + defines += toDefine("FFT_VARIANT", fft.variant); +#if FFT_FP64 | FFT_FP32 defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); defines += toDefine("IWEIGHT_STEP", invWeightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); - defines += toDefine("FFT_VARIANT", fft.variant); defines += toDefine("TAILT", root1Fancy(fft.shape.height * 2, 1)); TrigCoefs coefs = trigCoefs(fft.shape.size() / 4); defines += toDefine("TRIG_SCALE", int(coefs.scale)); defines += toDefine("TRIG_SIN", coefs.sinCoefs); defines += toDefine("TRIG_COS", coefs.cosCoefs); +#endif +#if NTT_GF31 + defines += toDefine("TAILTGF31", root1GF31(fft.shape.height * 2, 1)); +#endif +#if NTT_GF61 + defines += toDefine("TAILTGF61", root1GF61(fft.shape.height * 2, 1)); +#endif + +// When using multiple NTT primes or hybrid FFT/NTT, each FFT/NTT prime's data buffer and trig values are combined into one buffer. +// The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". +#if FFT_FP64 & NTT_GF31 + // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. + defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); +#elif FFT_FP32 & NTT_GF31 + // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); +#elif FFT_FP32 & NTT_GF61 + // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); +#elif NTT_GF31 & NTT_GF61 + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); + // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. + defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); +#elif NTT_GF31 + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); +#elif NTT_GF61 + defines += toDefine("DISTGF61", 0); + defines += toDefine("DISTWTRIGGF61", 0); + defines += toDefine("DISTMTRIGGF61", 0); + defines += toDefine("DISTHTRIGGF61", 0); +#endif // Calculate fractional bits-per-word = (E % N) / N * 2^64 u32 bpw_hi = (u64(E % N) << 32) / N; @@ -403,86 +517,110 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& hN(N / 2), nW(fft.shape.nW()), nH(fft.shape.nH()), - bufSize(N * sizeof(double)), useLongCarry{args.carry == Args::CARRY_LONG}, compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, tail_single_wide, tail_single_kernel, tail_trigs, pad_size)}, #define K(name, ...) name(#name, &compiler, profile.make(#name), queue, __VA_ARGS__) - // W / nW - K(kCarryFused, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW), - K(kCarryFusedROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DROE=1"), - - K(kCarryFusedMul, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1"), - K(kCarryFusedMulROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1 -DROE=1"), - - K(kCarryFusedLL, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DLL=1"), - - K(kCarryA, "carry.cl", "carry", hN / CARRY_LEN), - K(kCarryAROE, "carry.cl", "carry", hN / CARRY_LEN, "-DROE=1"), - - K(kCarryM, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1"), - K(kCarryMROE, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1 -DROE=1"), +#if FFT_FP64 | FFT_FP32 + K(kfftMidIn, "fftmiddlein.cl", "fftMiddleIn", hN / (BIG_H / SMALL_H)), + K(kfftHin, "ffthin.cl", "fftHin", hN / nH), + K(ktailSquareZero, "tailsquare.cl", "tailSquareZero", SMALL_H / nH * 2), + K(ktailSquare, "tailsquare.cl", "tailSquare", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels + !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel + !tail_single_kernel ? hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels + hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMul, "tailmul.cl", "tailMul", hN / nH / 2), + K(ktailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOut, "fftmiddleout.cl", "fftMiddleOut", hN / (BIG_H / SMALL_H)), + K(kfftW, "fftw.cl", "fftW", hN / nW), +#endif - K(kCarryLL, "carry.cl", "carry", hN / CARRY_LEN, "-DLL=1"), - K(carryB, "carryb.cl", "carryB", hN / CARRY_LEN), +#if NTT_GF31 + K(kfftMidInGF31, "fftmiddlein.cl", "fftMiddleInGF31", hN / (BIG_H / SMALL_H)), + K(kfftHinGF31, "ffthin.cl", "fftHinGF31", hN / nH), + K(ktailSquareZeroGF31, "tailsquare.cl", "tailSquareZeroGF31", SMALL_H / nH * 2), + K(ktailSquareGF31, "tailsquare.cl", "tailSquareGF31", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels + !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel + !tail_single_kernel ? hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels + hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMulGF31, "tailmul.cl", "tailMulGF31", hN / nH / 2), + K(ktailMulLowGF31, "tailmul.cl", "tailMulGF31", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOutGF31, "fftmiddleout.cl", "fftMiddleOutGF31", hN / (BIG_H / SMALL_H)), + K(kfftWGF31, "fftw.cl", "fftWGF31", hN / nW), +#endif - K(fftP, "fftp.cl", "fftP", hN / nW), - K(fftW, "fftw.cl", "fftW", hN / nW), - - // SMALL_H / nH - K(fftHin, "ffthin.cl", "fftHin", hN / nH), - K(tailSquareZero, "tailsquare.cl", "tailSquareZero", SMALL_H / nH * 2), - K(tailSquare, "tailsquare.cl", "tailSquare", !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels +#if NTT_GF61 + K(kfftMidInGF61, "fftmiddlein.cl", "fftMiddleInGF61", hN / (BIG_H / SMALL_H)), + K(kfftHinGF61, "ffthin.cl", "fftHinGF61", hN / nH), + K(ktailSquareZeroGF61, "tailsquare.cl", "tailSquareZeroGF61", SMALL_H / nH * 2), + K(ktailSquareGF61, "tailsquare.cl", "tailSquareGF61", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel !tail_single_kernel ? hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMulGF61, "tailmul.cl", "tailMulGF61", hN / nH / 2), + K(ktailMulLowGF61, "tailmul.cl", "tailMulGF61", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOutGF61, "fftmiddleout.cl", "fftMiddleOutGF61", hN / (BIG_H / SMALL_H)), + K(kfftWGF61, "fftw.cl", "fftWGF61", hN / nW), +#endif + + K(kfftP, "fftp.cl", "fftP", hN / nW), + K(kCarryA, "carry.cl", "carry", hN / CARRY_LEN), + K(kCarryAROE, "carry.cl", "carry", hN / CARRY_LEN, "-DROE=1"), + K(kCarryM, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1"), + K(kCarryMROE, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1 -DROE=1"), + K(kCarryLL, "carry.cl", "carry", hN / CARRY_LEN, "-DLL=1"), + K(kCarryFused, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW), + K(kCarryFusedROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DROE=1"), + K(kCarryFusedMul, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1"), + K(kCarryFusedMulROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1 -DROE=1"), + K(kCarryFusedLL, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DLL=1"), + + K(carryB, "carryb.cl", "carryB", hN / CARRY_LEN), - K(tailMul, "tailmul.cl", "tailMul", hN / nH / 2), - K(tailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"), - - // 256 - K(fftMidIn, "fftmiddlein.cl", "fftMiddleIn", hN / (BIG_H / SMALL_H)), - K(fftMidOut, "fftmiddleout.cl", "fftMiddleOut", hN / (BIG_H / SMALL_H)), - // 64 K(transpIn, "transpose.cl", "transposeIn", hN / 64), K(transpOut, "transpose.cl", "transposeOut", hN / 64), - + K(readResidue, "etc.cl", "readResidue", 32, "-DREADRESIDUE=1"), // 256 K(kernIsEqual, "etc.cl", "isEqual", 256 * 256, "-DISEQUAL=1"), K(sum64, "etc.cl", "sum64", 256 * 256, "-DSUM64=1"), + +#if FFT_FP64 K(testTrig, "selftest.cl", "testTrig", 256 * 256), K(testFFT4, "selftest.cl", "testFFT4", 256), - K(testFFT, "selftest.cl", "testFFT", 256), - K(testFFT15, "selftest.cl", "testFFT15", 256), K(testFFT14, "selftest.cl", "testFFT14", 256), + K(testFFT15, "selftest.cl", "testFFT15", 256), + K(testFFT, "selftest.cl", "testFFT", 256), +#endif K(testTime, "selftest.cl", "testTime", 4096 * 64), + #undef K - bufTrigW{shared.bufCache->smallTrig(WIDTH, nW)}, bufTrigH{shared.bufCache->smallTrigCombo(WIDTH, fft.shape.middle, SMALL_H, nH, fft.variant, tail_single_wide, tail_trigs)}, bufTrigM{shared.bufCache->middleTrig(SMALL_H, BIG_H / SMALL_H, WIDTH)}, + bufTrigW{shared.bufCache->smallTrig(WIDTH, nW, fft.shape.middle, SMALL_H, nH, fft.variant, tail_single_wide, tail_trigs)}, +#if FFT_FP64 | FFT_FP32 weights{genWeights(E, WIDTH, BIG_H, nW, isAmdGpu(q->context->deviceId()))}, - bufConstWeights{q->context, std::move(weights.weightsConstIF)}, bufWeights{q->context, std::move(weights.weightsIF)}, bufBits{q->context, std::move(weights.bitsCF)}, - bufBitsC{q->context, std::move(weights.bitsC)}, +#endif #define BUF(name, ...) name{profile.make(#name), queue, __VA_ARGS__} BUF(bufData, N), BUF(bufAux, N), - BUF(bufCheck, N), - BUF(bufBase, N), // Every double-word (i.e. N/2) produces one carry. In addition we may have one extra group thus WIDTH more carries. - BUF(bufCarry, N / 2 + WIDTH), + BUF(bufCarry, N / 2 + WIDTH), BUF(bufReady, (N / 2 + WIDTH) / 32), // Every wavefront (32 or 64 lanes) needs to signal "carry is ready" BUF(bufSmallOut, 256), @@ -491,12 +629,9 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& BUF(bufROE, ROE_SIZE), BUF(bufStatsCarry, CARRY_SIZE), - // Allocate extra for padding. We can probably tighten up the amount of extra memory allocated. - // The worst case seems to be MIDDLE=4, PAD_SIZE=512 - #define total_padding (((pad_size == 0 ? 0 : (pad_size <= 128 ? N/8 : (pad_size <= 256 ? N/4 : N/2)))) * (fft.shape.middle == 4 ? 5 : 4) / 4) - BUF(buf1, N + total_padding), - BUF(buf2, N + total_padding), - BUF(buf3, N + total_padding), + BUF(buf1, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf2, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf3, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), #undef BUF statsBits{u32(args.value("STATS", 0))}, @@ -514,41 +649,74 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& } } +#if !FFT_FP32 if (bitsPerWord < FFTShape::MIN_BPW) { log("FFT size too large for exponent (%.2f bits/word < %.2f bits/word).\n", bitsPerWord, FFTShape::MIN_BPW); throw "FFT size too large"; } +#endif useLongCarry = useLongCarry || (bitsPerWord < 12.0); if (useLongCarry) { log("Using long carry!\n"); } - + +#if FFT_FP64 | FFT_FP32 + kfftMidIn.setFixedArgs(2, bufTrigM); + kfftHin.setFixedArgs(2, bufTrigH); + ktailSquareZero.setFixedArgs(2, bufTrigH); + ktailSquare.setFixedArgs(2, bufTrigH); + ktailMulLow.setFixedArgs(3, bufTrigH); + ktailMul.setFixedArgs(3, bufTrigH); + kfftMidOut.setFixedArgs(2, bufTrigM); + kfftW.setFixedArgs(2, bufTrigW); +#endif + +#if NTT_GF31 + kfftMidInGF31.setFixedArgs(2, bufTrigM); + kfftHinGF31.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF31.setFixedArgs(2, bufTrigH); + ktailSquareGF31.setFixedArgs(2, bufTrigH); + ktailMulLowGF31.setFixedArgs(3, bufTrigH); + ktailMulGF31.setFixedArgs(3, bufTrigH); + kfftMidOutGF31.setFixedArgs(2, bufTrigM); + kfftWGF31.setFixedArgs(2, bufTrigW); +#endif + +#if NTT_GF61 + kfftMidInGF61.setFixedArgs(2, bufTrigM); + kfftHinGF61.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF61.setFixedArgs(2, bufTrigH); + ktailSquareGF61.setFixedArgs(2, bufTrigH); + ktailMulLowGF61.setFixedArgs(3, bufTrigH); + ktailMulGF61.setFixedArgs(3, bufTrigH); + kfftMidOutGF61.setFixedArgs(2, bufTrigM); + kfftWGF61.setFixedArgs(2, bufTrigW); +#endif + +#if FFT_FP64 | FFT_FP32 // The FP versions take bufWeight arguments (and bufBits which may be deleted) + kfftP.setFixedArgs(2, bufTrigW, bufWeights); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry, bufWeights); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(5, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(5, bufROE); } for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { k->setFixedArgs(3, bufCarry, bufReady, bufTrigW, bufBits, bufConstWeights, bufWeights); } - for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(9, bufROE); } for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(9, bufStatsCarry); } - - for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { - k->setFixedArgs(3, bufCarry, bufBitsC, bufWeights); +#else + kfftP.setFixedArgs(2, bufTrigW); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(4, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(4, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { + k->setFixedArgs(3, bufCarry, bufReady, bufTrigW); } + for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(6, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(6, bufStatsCarry); } +#endif - for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(6, bufROE); } - for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(6, bufStatsCarry); } - - fftP.setFixedArgs(2, bufTrigW, bufWeights); - fftW.setFixedArgs(2, bufTrigW); - fftHin.setFixedArgs(2, bufTrigH); + carryB.setFixedArgs(1, bufCarry); - fftMidIn.setFixedArgs( 2, bufTrigM); - fftMidOut.setFixedArgs(2, bufTrigM); - - carryB.setFixedArgs(1, bufCarry, bufBitsC); - tailMulLow.setFixedArgs(3, bufTrigH); - tailMul.setFixedArgs(3, bufTrigH); - tailSquareZero.setFixedArgs(2, bufTrigH); - tailSquare.setFixedArgs(2, bufTrigH); kernIsEqual.setFixedArgs(2, bufTrue); bufReady.zero(); @@ -563,6 +731,142 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& queue->finish(); } + +// Call the appropriate kernels to support hybrid FFTs and NTTs + +void Gpu::fftP(Buffer& out, Buffer& in) { + kfftP(out, in); +} + +void Gpu::fftW(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + kfftW(out, in); +#endif +#if NTT_GF31 + kfftWGF31(out, in); +#endif +#if NTT_GF61 + kfftWGF61(out, in); +#endif +} + +void Gpu::fftMidIn(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + kfftMidIn(out, in); +#endif +#if NTT_GF31 + kfftMidInGF31(out, in); +#endif +#if NTT_GF61 + kfftMidInGF61(out, in); +#endif +} + +void Gpu::fftMidOut(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + kfftMidOut(out, in); +#endif +#if NTT_GF31 + kfftMidOutGF31(out, in); +#endif +#if NTT_GF61 + kfftMidOutGF61(out, in); +#endif +} + +void Gpu::fftHin(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + kfftHin(out, in); +#endif +#if NTT_GF31 + kfftHinGF31(out, in); +#endif +#if NTT_GF61 + kfftHinGF61(out, in); +#endif +} + +void Gpu::tailSquareZero(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + ktailSquareZero(out, in); +#endif +#if NTT_GF31 + ktailSquareZeroGF31(out, in); +#endif +#if NTT_GF61 + ktailSquareZeroGF61(out, in); +#endif +} + +void Gpu::tailSquare(Buffer& out, Buffer& in) { +#if FFT_FP64 | FFT_FP32 + ktailSquare(out, in); +#endif +#if NTT_GF31 + ktailSquareGF31(out, in); +#endif +#if NTT_GF61 + ktailSquareGF61(out, in); +#endif +} + +void Gpu::tailMul(Buffer& out, Buffer& in1, Buffer& in2) { +#if FFT_FP64 | FFT_FP32 + ktailMul(out, in1, in2); +#endif +#if NTT_GF31 + ktailMulGF31(out, in1, in2); +#endif +#if NTT_GF61 + ktailMulGF61(out, in1, in2); +#endif +} + +void Gpu::tailMulLow(Buffer& out, Buffer& in1, Buffer& in2) { +#if FFT_FP64 | FFT_FP32 + ktailMulLow(out, in1, in2); +#endif +#if NTT_GF31 + ktailMulLowGF31(out, in1, in2); +#endif +#if NTT_GF61 + ktailMulLowGF61(out, in1, in2); +#endif +} + +void Gpu::carryA(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryAROE(out, in, roePos++) + : kCarryA(out, in, updateCarryPos(1 << 2)); +} + +void Gpu::carryM(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryMROE(out, in, roePos++) + : kCarryM(out, in, updateCarryPos(1 << 3)); +} + +void Gpu::carryLL(Buffer& out, Buffer& in) { + kCarryLL(out, in, updateCarryPos(1 << 2)); +} + +void Gpu::carryFused(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryFusedROE(out, in, roePos++) + : kCarryFused(out, in, updateCarryPos(1 << 0)); +} + +void Gpu::carryFusedMul(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryFusedMulROE(out, in, roePos++) + : kCarryFusedMul(out, in, updateCarryPos(1 << 1)); +} + +void Gpu::carryFusedLL(Buffer& out, Buffer& in) { + kCarryFusedLL(out, in, updateCarryPos(1 << 0)); +} + + #if 0 void Gpu::measureTransferSpeed() { u32 SIZE_MB = 16; @@ -589,34 +893,8 @@ u32 Gpu::updateCarryPos(u32 bit) { return (statsBits & bit) && (carryPos < CARRY_SIZE) ? carryPos++ : carryPos; } -void Gpu::carryFused(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryFusedROE(a, b, roePos++) - : kCarryFused(a, b, updateCarryPos(1 << 0)); -} - -void Gpu::carryFusedMul(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryFusedMulROE(a, b, roePos++) - : kCarryFusedMul(a, b, updateCarryPos(1 << 1)); -} - -void Gpu::carryA(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryAROE(a, b, roePos++) - : kCarryA(a, b, updateCarryPos(1 << 2)); -} - -void Gpu::carryLL(Buffer& a, Buffer& b) { kCarryLL(a, b, updateCarryPos(1 << 2)); } - -void Gpu::carryM(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryMROE(a, b, roePos++) - : kCarryM(a, b, updateCarryPos(1 << 3)); -} - -vector> Gpu::makeBufVector(u32 size) { - vector> r; +vector> Gpu::makeBufVector(u32 size) { + vector> r; for (u32 i = 0; i < size; ++i) { r.emplace_back(timeBufVect, queue, N); } return r; } @@ -674,20 +952,23 @@ template static bool isAllZero(vector v) { return std::all_of(v.begin(), v.end(), [](T x) { return x == 0;}); } // Read from GPU, verifying the transfer with a sum, and retry on failure. -vector Gpu::readChecked(Buffer& buf) { +vector Gpu::readChecked(Buffer& buf) { for (int nRetry = 0; nRetry < 3; ++nRetry) { - sum64(bufSumOut, u32(buf.size * sizeof(int)), buf); + sum64(bufSumOut, u32(buf.size * sizeof(Word)), buf); vector expectedVect(1); bufSumOut.readAsync(expectedVect); - vector data = readOut(buf); + vector data = readOut(buf); u64 gpuSum = expectedVect[0]; u64 hostSum = 0; - for (auto it = data.begin(), end = data.end(); it < end; it += 2) { - hostSum += u32(*it) | (u64(*(it + 1)) << 32); + int even = 1; + for (auto it = data.begin(), end = data.end(); it < end; ++it, even = !even) { + if (sizeof(Word) == 4) hostSum += even ? u64(u32(*it)) : (u64(*it) << 32); + if (sizeof(Word) == 8) hostSum += u64(*it); + if (sizeof(Word) == 16) hostSum += u64(*it) + u64((__int128) *it >> 64); } if (hostSum == gpuSum) { @@ -704,12 +985,12 @@ vector Gpu::readChecked(Buffer& buf) { throw "GPU persistent read errors"; } -Words Gpu::readAndCompress(Buffer& buf) { return compactBits(readChecked(buf), E); } +Words Gpu::readAndCompress(Buffer& buf) { return compactBits(readChecked(buf), E); } vector Gpu::readCheck() { return readAndCompress(bufCheck); } vector Gpu::readData() { return readAndCompress(bufData); } // out := inA * inB; -void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { +void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { fftP(tmp1, ioA); fftMidIn(tmp2, tmp1); tailMul(tmp1, inB, tmp2); @@ -723,13 +1004,13 @@ void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffe carryB(ioA); } -void Gpu::mul(Buffer& io, Buffer& buf1) { +void Gpu::mul(Buffer& io, Buffer& buf1) { // We know that mul() stores double output in buf1; so we're going to use buf2 & buf3 for temps. mul(io, buf1, buf2, buf3, false); } // out := inA * inB; -void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { +void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { fftP(buf2, inB); fftMidIn(buf1, buf2); mul(ioA, buf1, buf2, buf3, mul3); @@ -741,25 +1022,25 @@ void Gpu::writeState(const vector& check, u32 blockSize) { bufData << bufCheck; bufAux << bufCheck; - - u32 n = 0; + + u32 n; for (n = 1; blockSize % (2 * n) == 0; n *= 2) { squareLoop(bufData, 0, n); modMul(bufData, bufAux); bufAux << bufData; } - + assert((n & (n - 1)) == 0); assert(blockSize % n == 0); - + blockSize /= n; assert(blockSize >= 2); - + for (u32 i = 0; i < blockSize - 2; ++i) { squareLoop(bufData, 0, n); modMul(bufData, bufAux); } - + squareLoop(bufData, 0, n); modMul(bufData, bufAux, true); } @@ -798,14 +1079,14 @@ void Gpu::logTimeKernels() { profile.reset(); } -vector Gpu::readOut(Buffer &buf) { +vector Gpu::readOut(Buffer &buf) { transpOut(bufAux, buf); return bufAux.read(); } -void Gpu::writeIn(Buffer& buf, const vector& words) { writeIn(buf, expandBits(words, N, E)); } +void Gpu::writeIn(Buffer& buf, const vector& words) { writeIn(buf, expandBits(words, N, E)); } -void Gpu::writeIn(Buffer& buf, vector&& words) { +void Gpu::writeIn(Buffer& buf, vector&& words) { bufAux.write(std::move(words)); transpIn(buf, bufAux); } @@ -831,7 +1112,7 @@ Words Gpu::expExp2(const Words& A, u32 n) { } // A:= A^h * B -void Gpu::expMul(Buffer& A, u64 h, Buffer& B) { +void Gpu::expMul(Buffer& A, u64 h, Buffer& B) { exponentiate(A, h, buf1, buf2, buf3); modMul(A, B); } @@ -856,7 +1137,7 @@ void Gpu::bottomHalf(Buffer& out, Buffer& inTmp) { } // See "left-to-right binary exponentiation" on wikipedia -void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3) { +void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3) { if (exp == 0) { bufInOut.set(1); } else if (exp > 1) { @@ -888,7 +1169,7 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buf } } -// does either carrryFused() or the expanded version depending on useLongCarry +// does either carryFused() or the expanded version depending on useLongCarry void Gpu::doCarry(Buffer& out, Buffer& in) { if (useLongCarry) { fftW(out, in); @@ -900,12 +1181,12 @@ void Gpu::doCarry(Buffer& out, Buffer& in) { } } -void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3, bool doLL) { +void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3, bool doLL) { // LL does not do Mul3 assert(!(doMul3 && doLL)); if (leadIn) { fftP(buf2, in); } - + bottomHalf(buf1, buf2); if (leadOut) { @@ -931,9 +1212,9 @@ void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, b } } -void Gpu::square(Buffer& io) { square(io, io, true, true, false, false); } +void Gpu::square(Buffer& io) { square(io, io, true, true, false, false); } -u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3) { +u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3) { assert(from < to); bool leadIn = true; for (u32 k = from; k < to; ++k) { @@ -944,17 +1225,17 @@ u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool do return to; } -bool Gpu::isEqual(Buffer& in1, Buffer& in2) { +bool Gpu::isEqual(Buffer& in1, Buffer& in2) { kernIsEqual(in1, in2); int isEq = 0; bufTrue.read(&isEq, 1); if (!isEq) { bufTrue.write({1}); } return isEq; } - -u64 Gpu::bufResidue(Buffer &buf) { + +u64 Gpu::bufResidue(Buffer &buf) { readResidue(bufSmallOut, buf); - int words[64]; + Word words[64]; bufSmallOut.read(words, 64); int carry = 0; @@ -964,10 +1245,10 @@ u64 Gpu::bufResidue(Buffer &buf) { int hasBits = 0; for (int k = 0; k < 32 && hasBits < 64; ++k) { u32 len = bitlen(N, E, k); - int w = words[32 + k] + carry; + Word w = words[32 + k] + carry; carry = (w < 0) ? -1 : 0; - if (w < 0) { w += (1 << len); } - assert(w >= 0 && w < (1 << len)); + if (w < 0) { w += (1LL << len); } + assert(w >= 0 && w < (1LL << len)); res |= u64(w) << hasBits; hasBits += len; } @@ -1015,8 +1296,8 @@ void Gpu::doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u3 auto [roeSq, roeMul] = readROE(); double z = roeSq.z(); zAvg.update(z, roeSq.N); - log("%sZ=%.0f (avg %.1f)%s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), - z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + log("%sZ=%.0f (avg %.1f), ROEmax=%.3f, ROEavg=%.3f, %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), + z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); if (roeSq.N > 2 && z < 20) { log("Danger ROE! Z=%.1f is too small, increase precision or FFT size!\n", z); @@ -1057,6 +1338,8 @@ int ulps(double a, double b) { } void Gpu::selftestTrig() { + +#if FFT_FP64 const u32 n = hN / 8; testTrig(buf1); vector trig = buf1.read(n * 2); @@ -1081,7 +1364,7 @@ void Gpu::selftestTrig() { if (c < refCos) { ++cdown; } double norm = trigNorm(c, s); - + if (norm < 1.0) { ++oneDown; } if (norm > 1.0) { ++oneUp; } } @@ -1091,7 +1374,8 @@ void Gpu::selftestTrig() { log("TRIG cos(): imperfect %d / %d (%.2f%%), balance %d\n", cup + cdown, n, (cup + cdown) * 100.0 / n, cup - cdown); log("TRIG norm: up %d, down %d\n", oneUp, oneDown); - +#endif + if (isAmdGpu(queue->context->deviceId())) { vector WHATS {"V_NOP", "V_ADD_I32", "V_FMA_F32", "V_ADD_F64", "V_FMA_F64", "V_MUL_F64", "V_MAD_U64_U32"}; for (int w = 0; w < int(WHATS.size()); ++w) { @@ -1474,15 +1758,17 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool doStop = (k % blockSize == 0) && (Signal::stopRequested() || (args.iters && k - startK >= args.iters)); bool leadOut = (k % blockSize == 0) || k == persistK || k == kEnd || useLongCarry; +//if (k%10==0) leadOut = true; //GWBUG assert(!doStop || leadOut); if (doStop) { log("Stopping, please wait..\n"); } square(bufData, bufData, leadIn, leadOut, false); +//if(leadOut)printf("k %d, Residue: %lX\n", k, bufResidue(bufData)); //GWBUG leadIn = leadOut; - + if (k == persistK) { - vector rawData = readChecked(bufData); + vector rawData = readChecked(bufData); if (rawData.empty()) { log("Data error ZERO\n"); ++nErrors; @@ -1514,7 +1800,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { float secsPerIt = iterationTimer.reset(k); queue->setSquareTime((int) (secsPerIt * 1'000'000)); - vector rawCheck = readChecked(bufCheck); + vector rawCheck = readChecked(bufCheck); if (rawCheck.empty()) { ++nErrors; log("%9u %016" PRIx64 " read NULL check\n", k, res); @@ -1577,7 +1863,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { queue->finish(); throw "stop requested"; } - + iterationTimer.reset(k); } } diff --git a/src/Gpu.h b/src/Gpu.h index d232157b..b5b424c8 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -26,12 +26,9 @@ struct Task; class Signal; class ProofSet; -using double2 = pair; using TrigBuf = Buffer; using TrigPtr = shared_ptr; -namespace fs = std::filesystem; - inline u64 residue(const Words& words) { return (u64(words[1]) << 32) | words[0]; } struct PRPResult { @@ -80,12 +77,20 @@ class RoeInfo { double gumbelMiu{}, gumbelBeta{}; }; +#if FFT_FP64 struct Weights { vector weightsConstIF; vector weightsIF; vector bitsCF; - vector bitsC; }; +#endif +#if FFT_FP32 +struct Weights { + vector weightsConstIF; + vector weightsIF; + vector bitsCF; +}; +#endif class Gpu { Queue* queue; @@ -104,50 +109,71 @@ class Gpu { u32 SMALL_H; u32 BIG_H; - u32 hN, nW, nH, bufSize; + u32 hN, nW, nH; bool useLongCarry; u32 wantROE{}; Profile profile{}; KernelCompiler compiler; - - Kernel kCarryFused; - Kernel kCarryFusedROE; - Kernel kCarryFusedMul; - Kernel kCarryFusedMulROE; - Kernel kCarryFusedLL; +#if FFT_FP64 | FFT_FP32 + Kernel kfftMidIn; + Kernel kfftHin; + Kernel ktailSquareZero; + Kernel ktailSquare; + Kernel ktailMul; + Kernel ktailMulLow; + Kernel kfftMidOut; + Kernel kfftW; +#endif + +#if NTT_GF31 + Kernel kfftMidInGF31; + Kernel kfftHinGF31; + Kernel ktailSquareZeroGF31; + Kernel ktailSquareGF31; + Kernel ktailMulGF31; + Kernel ktailMulLowGF31; + Kernel kfftMidOutGF31; + Kernel kfftWGF31; +#endif + +#if NTT_GF61 + Kernel kfftMidInGF61; + Kernel kfftHinGF61; + Kernel ktailSquareZeroGF61; + Kernel ktailSquareGF61; + Kernel ktailMulGF61; + Kernel ktailMulLowGF61; + Kernel kfftMidOutGF61; + Kernel kfftWGF61; +#endif + + Kernel kfftP; Kernel kCarryA; Kernel kCarryAROE; Kernel kCarryM; Kernel kCarryMROE; Kernel kCarryLL; - Kernel carryB; - - Kernel fftP; - Kernel fftW; - - Kernel fftHin; - - Kernel tailSquareZero; - Kernel tailSquare; - Kernel tailMul; - Kernel tailMulLow; - - Kernel fftMidIn; - Kernel fftMidOut; + Kernel kCarryFused; + Kernel kCarryFusedROE; + Kernel kCarryFusedMul; + Kernel kCarryFusedMulROE; + Kernel kCarryFusedLL; + Kernel carryB; Kernel transpIn, transpOut; - Kernel readResidue; Kernel kernIsEqual; Kernel sum64; +#if FFT_FP64 Kernel testTrig; Kernel testFFT4; - Kernel testFFT; - Kernel testFFT15; Kernel testFFT14; + Kernel testFFT15; + Kernel testFFT; +#endif Kernel testTime; // Kernel testKernel; @@ -156,39 +182,42 @@ class Gpu { bool tail_single_wide; // TailSquare processes one line at a time bool tail_single_kernel; // TailSquare does not use a separate kernel for line zero u32 tail_trigs; // 0,1,2. Increasing values use more DP and less memory accesses - u32 pad_size; // Pad size in bytes + u32 pad_size; // Pad size in bytes as specified on the command line or config.txt. Maximum value is 512. // Twiddles: trigonometry constant buffers, used in FFTs. // The twiddles depend only on FFT config and do not depend on the exponent. - TrigPtr bufTrigW; + // It is important to generate the height trigs before the width trigs because width trigs can be a subset of the height trigs TrigPtr bufTrigH; TrigPtr bufTrigM; - - Weights weights; + TrigPtr bufTrigW; // The weights and the "bigWord bits" depend on the exponent. +#if FFT_FP64 + Weights weights; Buffer bufConstWeights; Buffer bufWeights; - Buffer bufBits; // bigWord bits aligned for CarryFused/fftP - Buffer bufBitsC; // bigWord bits aligned for CarryA/M +#endif +#if FFT_FP32 + Weights weights; + Buffer bufConstWeights; + Buffer bufWeights; + Buffer bufBits; // bigWord bits aligned for CarryFused/fftP +#endif // "integer word" buffers. These are "small buffers": N x int. - Buffer bufData; // Main int buffer with the words. - Buffer bufAux; // Auxiliary int buffer, used in transposing data in/out and in check. - Buffer bufCheck; // Buffers used with the error check. - Buffer bufBase; // used in P-1 error check. + Buffer bufData; // Main int buffer with the words. + Buffer bufAux; // Auxiliary int buffer, used in transposing data in/out and in check. + Buffer bufCheck; // Buffers used with the error check. // Carry buffers, used in carry and fusedCarry. Buffer bufCarry; // Carry shuttle. - Buffer bufReady; // Per-group ready flag for stairway carry propagation. // Small aux buffers. - Buffer bufSmallOut; + Buffer bufSmallOut; Buffer bufSumOut; Buffer bufTrue; - Buffer bufROE; // The round-off error ("ROE"), one float element per iteration. Buffer bufStatsCarry; @@ -207,46 +236,64 @@ class Gpu { TimeInfo* timeBufVect; ZAvg zAvg; - vector readOut(Buffer &buf); - void writeIn(Buffer& buf, vector&& words); - - void square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3 = false, bool doLL = false); - void squareCERT(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, false); } - void squareLL(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, true); } - - void square(Buffer& io); - - u32 squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3); - u32 squareLoop(Buffer& io, u32 from, u32 to) { return squareLoop(io, io, from, to, false); } - - bool isEqual(Buffer& bufCheck, Buffer& bufAux); - u64 bufResidue(Buffer& buf); + void fftP(Buffer& out, Buffer& in) { fftP(out, reinterpret_cast&>(in)); } + void fftP(Buffer& out, Buffer& in); + void fftMidIn(Buffer& out, Buffer& in); + void fftMidOut(Buffer& out, Buffer& in); + void fftHin(Buffer& out, Buffer& in); + void tailSquareZero(Buffer& out, Buffer& in); + void tailSquare(Buffer& out, Buffer& in); + void tailMul(Buffer& out, Buffer& in1, Buffer& in2); + void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2); + void fftW(Buffer& out, Buffer& in); + void carryA(Buffer& out, Buffer& in) { carryA(reinterpret_cast&>(out), in); } + void carryA(Buffer& out, Buffer& in); + void carryM(Buffer& out, Buffer& in); + void carryLL(Buffer& out, Buffer& in); + void carryFused(Buffer& out, Buffer& in); + void carryFusedMul(Buffer& out, Buffer& in); + void carryFusedLL(Buffer& out, Buffer& in); + + vector readOut(Buffer &buf); + void writeIn(Buffer& buf, vector&& words); + + void square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3 = false, bool doLL = false); + void squareCERT(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, false); } + void squareLL(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, true); } + + void square(Buffer& io); + + u32 squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3); + u32 squareLoop(Buffer& io, u32 from, u32 to) { return squareLoop(io, io, from, to, false); } + + bool isEqual(Buffer& bufCheck, Buffer& bufAux); + u64 bufResidue(Buffer& buf); vector writeBase(const vector &v); - void exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3); + void exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3); void bottomHalf(Buffer& out, Buffer& inTmp); void writeState(const vector& check, u32 blockSize); - + // does either carrryFused() or the expanded version depending on useLongCarry void doCarry(Buffer& out, Buffer& in); - void mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3 = false); - void mul(Buffer& io, Buffer& inB); + void mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3 = false); + void mul(Buffer& io, Buffer& inB); + + void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); - void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); - fs::path saveProof(const Args& args, const ProofSet& proofSet); std::pair readROE(); RoeInfo readCarryStats(); - + u32 updateCarryPos(u32 bit); PRPState loadPRP(Saver& saver); - vector readChecked(Buffer& buf); + vector readChecked(Buffer& buf); // void measureTransferSpeed(); @@ -272,21 +319,7 @@ class Gpu { Saver *getSaver(); - void carryA(Buffer& a, Buffer& b) { carryA(reinterpret_cast&>(a), b); } - - void carryA(Buffer& a, Buffer& b); - - void carryM(Buffer& a, Buffer& b); - - void carryLL(Buffer& a, Buffer& b); - - void carryFused(Buffer& a, Buffer& b); - - void carryFusedMul(Buffer& a, Buffer& b); - - void carryFusedLL(Buffer& a, Buffer& b) { kCarryFusedLL(a, b, updateCarryPos(1<<0));} - - void writeIn(Buffer& buf, const vector &words); + void writeIn(Buffer& buf, const vector &words); u64 dataResidue() { return bufResidue(bufData); } u64 checkResidue() { return bufResidue(bufCheck); } @@ -295,7 +328,7 @@ class Gpu { void logTimeKernels(); - Words readAndCompress(Buffer& buf); + Words readAndCompress(Buffer& buf); vector readCheck(); vector readData(); @@ -309,11 +342,11 @@ class Gpu { Words expMul2(const Words& A, u64 h, const Words& B); // A:= A^h * B - void expMul(Buffer& A, u64 h, Buffer& B); + void expMul(Buffer& A, u64 h, Buffer& B); // return A^(2^n) Words expExp2(const Words& A, u32 n); - vector> makeBufVector(u32 size); + vector> makeBufVector(u32 size); void clear(bool isPRP); @@ -321,3 +354,15 @@ class Gpu { u32 getProofPower(u32 k); void doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u32 nErrors); }; + +// Compute the size of an FFT/NTT data buffer depending on the FFT/NTT float/prime. Size is returned in units of sizeof(double). +// Data buffers require extra space for padding. We can probably tighten up the amount of extra memory allocated. +// The worst case seems to be MIDDLE=4, PAD_SIZE=512. + +#define MID_ADJUST(size,M,pad) ((pad == 0 || M != 4) ? (size) : (size) * 5/4) +#define PAD_ADJUST(N,M,pad) MID_ADJUST(pad == 0 ? N : pad <= 128 ? 9*N/8 : pad <= 256 ? 5*N/4 : 3*N/2, M, pad) +#define FP64_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) +#define FP32_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(float) / sizeof(double) +#define GF31_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(uint) / sizeof(double) +#define GF61_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(ulong) / sizeof(double) +#define TOTAL_DATA_SIZE(W,M,H,pad) FFT_FP64 * FP64_DATA_SIZE(W,M,H,pad) + FFT_FP32 * FP32_DATA_SIZE(W,M,H,pad) + NTT_GF31 * GF31_DATA_SIZE(W,M,H,pad) + NTT_GF61 * GF61_DATA_SIZE(W,M,H,pad) diff --git a/src/Proof.cpp b/src/Proof.cpp index 4f1d01e1..31c488f8 100644 --- a/src/Proof.cpp +++ b/src/Proof.cpp @@ -268,7 +268,7 @@ std::pair> ProofSet::computeProof(Gpu *gpu) const { auto hash = proof::hashWords(E, B); - vector> bufVect = gpu->makeBufVector(power); + vector> bufVect = gpu->makeBufVector(power); for (u32 p = 0; p < power; ++p) { auto bufIt = bufVect.begin(); diff --git a/src/Queue.cpp b/src/Queue.cpp index d2c1ceac..829be7ae 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -58,7 +58,6 @@ void Queue::add(EventHolder&& e, TimeInfo* ti) { } void Queue::readSync(cl_mem buf, u32 size, void* out, TimeInfo* tInfo) { - queueMarkerEvent(); add(read(get(), {}, true, buf, size, out, hasEvents), tInfo); events.synced(); } diff --git a/src/Task.cpp b/src/Task.cpp index db36e782..7e148181 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -40,7 +40,7 @@ constexpr int platform() { const constexpr bool IS_32BIT = (sizeof(void*) == 4); -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) || defined(__MINGW32__) || defined(__MINGW64__) return IS_32BIT ? WIN_32 : WIN_64; #elif __APPLE__ @@ -83,6 +83,24 @@ OsInfo getOsInfo() { return getOsInfoMinimum(); } #endif +#if FFT_FP64 & NTT_GF31 +#define JSON_FFT_TYPE json("fft-type", "FP64+M31") +#elif FFT_FP64 +#define JSON_FFT_TYPE json("fft-type", "FP64") +#elif FFT_FP32 & NTT_GF31 +#define JSON_FFT_TYPE json("fft-type", "FP32+M31") +#elif FFT_FP32 & NTT_GF61 +#define JSON_FFT_TYPE json("fft-type", "FP32+M61") +#elif NTT_GF31 & NTT_GF61 +#define JSON_FFT_TYPE json("fft-type", "M31+M61") +#elif FFT_FP32 +#define JSON_FFT_TYPE json("fft-type", "FP32") +#elif NTT_GF31 +#define JSON_FFT_TYPE json("fft-type", "M31") +#elif NTT_GF61 +#define JSON_FFT_TYPE json("fft-type", "M61") +#endif + string json(const vector& v) { bool isFirst = true; string s = "{"; @@ -150,6 +168,7 @@ void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res6 json("res2048", res2048), json("residue-type", 1), json("errors", vector{json("gerbicz", nErrors)}), + JSON_FFT_TYPE, json("fft-length", fftSize) }; @@ -171,6 +190,7 @@ void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res6 void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64, u32 fftSize) const { vector fields{json("res64", hex(res64)), + JSON_FFT_TYPE, json("fft-length", fftSize), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this @@ -182,9 +202,10 @@ void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64 void Task::writeResultCERT(const Args &args, u32 instance, array hash, u32 squarings, u32 fftSize) const { string hexhash = hex(hash[3]) + hex(hash[2]) + hex(hash[1]) + hex(hash[0]); vector fields{json("worktype", "Cert"), - json("exponent", exponent), - json("sha3-hash", hexhash.c_str()), - json("squarings", squarings), + json("exponent", exponent), + json("sha3-hash", hexhash.c_str()), + json("squarings", squarings), + JSON_FFT_TYPE, json("fft-length", fftSize), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this diff --git a/src/TrigBufCache.cpp b/src/TrigBufCache.cpp index 7646a724..734b9487 100644 --- a/src/TrigBufCache.cpp +++ b/src/TrigBufCache.cpp @@ -1,9 +1,12 @@ - // Copyright Mihai Preda +// Copyright Mihai Preda +#include #include "TrigBufCache.h" +#if FFT_FP64 + #define SAVE_ONE_MORE_WIDTH_MUL 0 // I want to make saving the only option -- but rocm optimizer is inexplicably making it slower in carryfused -#define SAVE_ONE_MORE_HEIGHT_MUL 1 // In tailSquar this is the fastest option +#define SAVE_ONE_MORE_HEIGHT_MUL 1 // In tailSquare this is the fastest option #define _USE_MATH_DEFINES #include @@ -119,7 +122,6 @@ double root1cosover(u32 N, u32 k, double over) { return double(c / over); } -namespace { static const constexpr bool LOG_TRIG_ALLOC = false; // Interleave two lines of trig values so that AMD GPUs can use global_load_dwordx4 instructions @@ -136,8 +138,8 @@ void T2shuffle(u32 size, u32 radix, u32 line, vector &tab) { } } -vector genSmallTrig(u32 size, u32 radix) { - if (LOG_TRIG_ALLOC) { log("genSmallTrig(%u, %u)\n", size, radix); } +vector genSmallTrigFP64(u32 size, u32 radix) { + if (LOG_TRIG_ALLOC) { log("genSmallTrigFP64(%u, %u)\n", size, radix); } u32 WG = size / radix; vector tab; @@ -221,10 +223,10 @@ vector genSmallTrig(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { - if (LOG_TRIG_ALLOC) { log("genSmallTrigCombo(%u, %u)\n", size, radix); } +vector genSmallTrigComboFP64(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { + if (LOG_TRIG_ALLOC) { log("genSmallTrigComboFP64(%u, %u)\n", size, radix); } - vector tab = genSmallTrig(size, radix); + vector tab = genSmallTrigFP64(size, radix); // From tailSquare pre-calculate some or all of these: T2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. @@ -234,14 +236,14 @@ vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bo tab.push_back(root1(width * middle * height, width * middle * me)); } // Output the one or two T2 multipliers to be read by one u,v pair of lines - for (u32 line = 0; line < width * middle / 2; ++line) { + for (u32 line = 0; line <= width * middle / 2; ++line) { tab.push_back(root1Fancy(width * middle * height, line)); if (!tail_single_wide) tab.push_back(root1Fancy(width * middle * height, line ? width * middle - line : width * middle / 2)); } } if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with lousy DP performance. u32 height = size; - for (u32 u = 0; u < width * middle / 2; ++u) { + for (u32 u = 0; u <= width * middle / 2; ++u) { for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); for (u32 me = 0; me < height / radix; ++me) { @@ -258,8 +260,8 @@ vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bo // cos-1 "fancy" trick. #define SHARP_MIDDLE 5 -vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { - if (LOG_TRIG_ALLOC) { log("genMiddleTrig(%u, %u, %u)\n", smallH, middle, width); } +vector genMiddleTrigFP64(u32 smallH, u32 middle, u32 width) { + if (LOG_TRIG_ALLOC) { log("genMiddleTrigFP64(%u, %u, %u)\n", smallH, middle, width); } vector tab; if (middle == 1) { tab.resize(1); @@ -267,49 +269,649 @@ vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { if (middle < SHARP_MIDDLE) { for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(smallH * middle, k)); } for (u32 k = 0; k < width; ++k) { tab.push_back(root1(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } } else { for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1Fancy(smallH * middle, k)); } for (u32 k = 0; k < width; ++k) { tab.push_back(root1Fancy(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } } } return tab; } -} // namespace +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on floats */ +/**************************************************************************/ + +#if FFT_FP32 + +#define _USE_MATH_DEFINES +#include + +#ifndef M_PI +#define M_PI 3.1415926535897931 +#endif + +// For small angles, return "fancy" cos - 1 for increased precision +float2 root1Fancy(u32 N, u32 k) { + assert(!(N&7)); + assert(k < N); + assert(k < N/4); + + double angle = M_PI * k / (N / 2); + return {float(cos(angle) - 1), float(sin(angle))}; +} + +static float trigNorm(float c, float s) { return c * c + s * s; } +static float trigError(float c, float s) { return abs(trigNorm(c, s) - 1.0f); } + +// Round trig double to float as to satisfy c^2 + s^2 == 1 as best as possible +static float2 roundTrig(double lc, double ls) { + float c1 = lc; + float c2 = nexttoward(c1, lc); + float s1 = ls; + float s2 = nexttoward(s1, ls); + + float c = c1; + float s = s1; + for (float tryC : {c1, c2}) { + for (float tryS : {s1, s2}) { + if (trigError(tryC, tryS) < trigError(c, s)) { + c = tryC; + s = tryS; + } + } + } + return {c, s}; +} + +// Returns the primitive root of unity of order N, to the power k. +float2 root1(u32 N, u32 k) { + assert(k < N); + if (k >= N/2) { + auto [c, s] = root1(N, k - N/2); + return {-c, -s}; + } else if (k > N/4) { + auto [c, s] = root1(N, N/2 - k); + return {-c, s}; + } else if (k > N/8) { + auto [c, s] = root1(N, N/4 - k); + return {s, c}; + } else { + assert(k <= N/8); + + double angle = M_PI * k / (N / 2); + return roundTrig(cos(angle), sin(angle)); + } +} + +vector genSmallTrigFP32(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + +// old fft_WIDTH and fft_HEIGHT + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(radix / line >= 8 ? root1Fancy(size, col * line) : root1(size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboFP32(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { + vector tab = genSmallTrigFP32(size, radix); + + // From tailSquare pre-calculate some or all of these: F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); + if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. + u32 height = size; + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1(width * middle * height, width * middle * me)); + } + // Output the one or two F2 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1Fancy(width * middle * height, line)); + if (!tail_single_wide) tab.push_back(root1Fancy(width * middle * height, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with lousy FP performance? + u32 height = size; + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1(width * middle * height, line + width * middle * me)); + } + } + } + } + + return tab; +} + +// starting from a MIDDLE of 5 we consider angles in [0, 2Pi/MIDDLE] as worth storing with the +// cos-1 "fancy" trick. +#define SHARP_MIDDLE 5 + +vector genMiddleTrigFP32(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + if (middle < SHARP_MIDDLE) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } + } else { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1Fancy(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1Fancy(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } + } + } + return tab; +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Z31 and GF31 code copied from Yves Gallot's mersenne2 program + +// Z/{2^31 - 1}Z: the prime field of order p = 2^31 - 1 +class Z31 +{ +private: + static const uint32_t _p = (uint32_t(1) << 31) - 1; + uint32_t _n; // 0 <= n < p + + static uint32_t _add(const uint32_t a, const uint32_t b) + { + const uint32_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint32_t _sub(const uint32_t a, const uint32_t b) + { + const uint32_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint32_t _mul(const uint32_t a, const uint32_t b) + { + const uint64_t t = a * uint64_t(b); + return _add(uint32_t(t) & _p, uint32_t(t >> 31)); + } + +public: + Z31() {} + explicit Z31(const uint32_t n) : _n(n) {} + + uint32_t get() const { return _n; } + + bool operator!=(const Z31 & rhs) const { return (_n != rhs._n); } + + // Z31 neg() const { return Z31((_n == 0) ? 0 : _p - _n); } + // Z31 half() const { return Z31(((_n % 2 == 0) ? _n : (_n + _p)) / 2); } + + Z31 operator+(const Z31 & rhs) const { return Z31(_add(_n, rhs._n)); } + Z31 operator-(const Z31 & rhs) const { return Z31(_sub(_n, rhs._n)); } + Z31 operator*(const Z31 & rhs) const { return Z31(_mul(_n, rhs._n)); } + + Z31 sqr() const { return Z31(_mul(_n, _n)); } +}; + + +// GF((2^31 - 1)^2): the prime field of order p^2, p = 2^31 - 1 +class GF31 +{ +private: + Z31 _s0, _s1; + // a primitive root of order 2^32 which is a root of (0, 1). + static const uint64_t _h_order = uint64_t(1) << 32; + static const uint32_t _h_0 = 7735u, _h_1 = 748621u; + +public: + GF31() {} + explicit GF31(const Z31 & s0, const Z31 & s1) : _s0(s0), _s1(s1) {} + explicit GF31(const uint32_t n0, const uint32_t n1) : _s0(n0), _s1(n1) {} + + const Z31 & s0() const { return _s0; } + const Z31 & s1() const { return _s1; } + + GF31 operator+(const GF31 & rhs) const { return GF31(_s0 + rhs._s0, _s1 + rhs._s1); } + GF31 operator-(const GF31 & rhs) const { return GF31(_s0 - rhs._s0, _s1 - rhs._s1); } + + GF31 sqr() const { const Z31 t = _s0 * _s1; return GF31(_s0.sqr() - _s1.sqr(), t + t); } + GF31 mul(const GF31 & rhs) const { return GF31(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + + GF31 pow(const uint64_t e) const + { + if (e == 0) return GF31(1u, 0u); + GF31 r = GF31(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } + + static const GF31 root_one(const size_t n) { return GF31(Z31(_h_0), Z31(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 30) / n) % 31); } +}; + +// Returns the primitive root of unity of order N, to the power k. +uint2 root1GF31(GF31 root1N, u32 k) { + GF31 x = root1N.pow(k); + return { x.s0().get(), x.s1().get() }; +} +uint2 root1GF31(u32 N, u32 k) { + assert(k < N); + GF31 root1N = GF31::root_one(N); + return root1GF31(root1N, k); +} + +vector genSmallTrigGF31(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + +// old fft_WIDTH and fft_HEIGHT + GF31 root1size = GF31::root_one(size); + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(root1GF31(root1size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboGF31(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { + vector tab = genSmallTrigGF31(size, radix); + + // From tailSquareGF31 pre-calculate some or all of these: GF31 trig = slowTrigGF31(line + H * lowMe, ND / NH * 2); + u32 height = size; + GF31 root1wmh = GF31::root_one(width * middle * height); + if (tail_trigs >= 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF31(root1wmh, width * middle * me)); + } + // Output the one or two GF31 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1GF31(root1wmh, line)); + if (!tail_single_wide) tab.push_back(root1GF31(root1wmh, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with great memory performance. + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF31(root1wmh, line + width * middle * me)); + } + } + } + } + + return tab; +} + +vector genMiddleTrigGF31(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + GF31 root1hm = GF31::root_one(smallH * middle); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1hm, k)); } + GF31 root1mw = GF31::root_one(middle * width); + for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF31(root1mw, k)); } + GF31 root1wmh = GF31::root_one(width * middle * smallH); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1wmh, k)); } + } + return tab; +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Z61 and GF61 code copied from Yves Gallot's mersenne2 program + +// Z/{2^61 - 1}Z: the prime field of order p = 2^61 - 1 +class Z61 +{ +private: + static const uint64_t _p = (uint64_t(1) << 61) - 1; + uint64_t _n; // 0 <= n < p + + static uint64_t _add(const uint64_t a, const uint64_t b) + { + const uint64_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint64_t _sub(const uint64_t a, const uint64_t b) + { + const uint64_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint64_t _mul(const uint64_t a, const uint64_t b) + { + const __uint128_t t = a * __uint128_t(b); + const uint64_t lo = uint64_t(t), hi = uint64_t(t >> 64); + const uint64_t lo61 = lo & _p, hi61 = (lo >> 61) | (hi << 3); + return _add(lo61, hi61); + } + +public: + Z61() {} + explicit Z61(const uint64_t n) : _n(n) {} + + uint64_t get() const { return _n; } + + bool operator!=(const Z61 & rhs) const { return (_n != rhs._n); } + + Z61 operator+(const Z61 & rhs) const { return Z61(_add(_n, rhs._n)); } + Z61 operator-(const Z61 & rhs) const { return Z61(_sub(_n, rhs._n)); } + Z61 operator*(const Z61 & rhs) const { return Z61(_mul(_n, rhs._n)); } + + Z61 sqr() const { return Z61(_mul(_n, _n)); } +}; + +// GF((2^61 - 1)^2): the prime field of order p^2, p = 2^61 - 1 +class GF61 +{ +private: + Z61 _s0, _s1; + // Primitive root of order 2^62 which is a root of (0, 1). This root corresponds to 2*pi*i*j/N in FFTs. PRPLL FFTs use this root. Thanks, Yves! + static const uint64_t _h_0 = 264036120304204ull, _h_1 = 4677669021635377ull; + // Primitive root of order 2^62 which is a root of (0, -1). This root corresponds to -2*pi*i*j/N in FFTs. + //static const uint64_t _h_0 = 481139922016222ull, _h_1 = 814659809902011ull; + static const uint64_t _h_order = uint64_t(1) << 62; + +public: + GF61() {} + explicit GF61(const Z61 & s0, const Z61 & s1) : _s0(s0), _s1(s1) {} + explicit GF61(const uint64_t n0, const uint64_t n1) : _s0(n0), _s1(n1) {} + + const Z61 & s0() const { return _s0; } + const Z61 & s1() const { return _s1; } + + GF61 operator+(const GF61 & rhs) const { return GF61(_s0 + rhs._s0, _s1 + rhs._s1); } + GF61 operator-(const GF61 & rhs) const { return GF61(_s0 - rhs._s0, _s1 - rhs._s1); } + + GF61 sqr() const { const Z61 t = _s0 * _s1; return GF61(_s0.sqr() - _s1.sqr(), t + t); } + GF61 mul(const GF61 & rhs) const { return GF61(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + + GF61 pow(const uint64_t e) const + { + if (e == 0) return GF61(1u, 0u); + GF61 r = GF61(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } + + static const GF61 root_one(const size_t n) { return GF61(Z61(_h_0), Z61(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 60) / n) % 61); } +}; + +// Returns the primitive root of unity of order N, to the power k. +ulong2 root1GF61(GF61 root1N, u32 k) { + GF61 x = root1N.pow(k); + return { x.s0().get(), x.s1().get() }; +} +ulong2 root1GF61(u32 N, u32 k) { + assert(k < N); + GF61 root1N = GF61::root_one(N); + return root1GF61(root1N, k); +} + +vector genSmallTrigGF61(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + +// old fft_WIDTH and fft_HEIGHT + GF61 root1size = GF61::root_one(size); + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(root1GF61(root1size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboGF61(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { + vector tab = genSmallTrigGF61(size, radix); + + // From tailSquareGF61 pre-calculate some or all of these: GF61 trig = slowTrigGF61(line + H * lowMe, ND / NH * 2); + u32 height = size; + GF61 root1wmh = GF61::root_one(width * middle * height); + if (tail_trigs >= 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF61(root1wmh, width * middle * me)); + } + // Output the one or two GF61 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1GF61(root1wmh, line)); + if (!tail_single_wide) tab.push_back(root1GF61(root1wmh, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with great memory performance. + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF61(root1wmh, line + width * middle * me)); + } + } + } + } + + return tab; +} + +vector genMiddleTrigGF61(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + GF61 root1hm = GF61::root_one(smallH * middle); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1hm, k)); } + GF61 root1mw = GF61::root_one(middle * width); + for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF61(root1mw, k)); } + GF61 root1wmh = GF61::root_one(width * middle * smallH); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1wmh, k)); } + } + return tab; +} + +#endif + + +/**********************************************************/ +/* Build all the needed trig values into one big buffer */ +/**********************************************************/ + +vector genSmallTrig(u32 size, u32 radix) { + vector tab; + u32 tabsize; + +#if FFT_FP64 + tab = genSmallTrigFP64(size, radix); + tab.resize(SMALLTRIG_FP64_SIZE(size, 0, 0, 0)); +#endif + +#if FFT_FP32 + vector tab1 = genSmallTrigFP32(size, radix); + tab1.resize(SMALLTRIG_FP32_SIZE(size, 0, 0, 0)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); +#endif + +#if NTT_GF31 + vector tab2 = genSmallTrigGF31(size, radix); + tab2.resize(SMALLTRIG_GF31_SIZE(size, 0, 0, 0)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); +#endif + +#if NTT_GF61 + vector tab3 = genSmallTrigGF61(size, radix); + tab3.resize(SMALLTRIG_GF61_SIZE(size, 0, 0, 0)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); +#endif + + return tab; +} + +vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { + vector tab; + u32 tabsize; + +#if FFT_FP64 + tab = genSmallTrigComboFP64(width, middle, size, radix, tail_single_wide, tail_trigs); + tab.resize(SMALLTRIGCOMBO_FP64_SIZE(width, middle, size, radix)); +#endif + +#if FFT_FP32 + vector tab1 = genSmallTrigComboFP32(width, middle, size, radix, tail_single_wide, tail_trigs); + tab1.resize(SMALLTRIGCOMBO_FP32_SIZE(width, middle, size, radix)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); +#endif + +#if NTT_GF31 + vector tab2 = genSmallTrigComboGF31(width, middle, size, radix, tail_single_wide, tail_trigs); + tab2.resize(SMALLTRIGCOMBO_GF31_SIZE(width, middle, size, radix)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); +#endif + +#if NTT_GF61 + vector tab3 = genSmallTrigComboGF61(width, middle, size, radix, tail_single_wide, tail_trigs); + tab3.resize(SMALLTRIGCOMBO_GF61_SIZE(width, middle, size, radix)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); +#endif + + return tab; +} + +vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { + vector tab; + u32 tabsize; + +#if FFT_FP64 + tab = genMiddleTrigFP64(smallH, middle, width); + tab.resize(MIDDLETRIG_FP64_SIZE(width, middle, smallH)); +#endif + +#if FFT_FP32 + vector tab1 = genMiddleTrigFP32(smallH, middle, width); + tab1.resize(MIDDLETRIG_FP32_SIZE(width, middle, smallH)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); +#endif + +#if NTT_GF31 + vector tab2 = genMiddleTrigGF31(smallH, middle, width); + tab2.resize(MIDDLETRIG_GF31_SIZE(width, middle, smallH)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); +#endif + +#if NTT_GF61 + vector tab3 = genMiddleTrigGF61(smallH, middle, width); + tab3.resize(MIDDLETRIG_GF61_SIZE(width, middle, smallH)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); +#endif + + return tab; +} + + +/********************************************************/ +/* Code to manage a cache of trigBuffers */ +/********************************************************/ TrigBufCache::~TrigBufCache() = default; -TrigPtr TrigBufCache::smallTrig(u32 W, u32 nW) { +TrigPtr TrigBufCache::smallTrig(u32 width, u32 nW, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs) { lock_guard lock{mut}; auto& m = small; - decay_t::key_type key{W, nW, 0, 0, 0, 0}; - TrigPtr p{}; - auto it = m.find(key); - if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genSmallTrig(W, nW)); - m[key] = p; - smallCache.add(p); + + // See if there is an existing smallTrigCombo that we can return (using only as subset of the data) + // In theory, we could match any smallTrigCombo where width matches. However, SMALLTRIG_GF31_SIZE wouldn't be able to figure out the size. + // In practice, those cases will likely never arise. + if (width == height && nW == nH) { + decay_t::key_type key{height, nH, width, middle, tail_single_wide, tail_trigs}; + auto it = m.find(key); + if (it != m.end() && (p = it->second.lock())) return p; } + + // See if there is an existing non-combo smallTrig that we can return + decay_t::key_type key{width, nW, 0, 0, 0, 0}; + auto it = m.find(key); + if (it != m.end() && (p = it->second.lock())) return p; + + // Create a new non-combo + p = make_shared(context, genSmallTrig(width, nW)); + m[key] = p; + smallCache.add(p); return p; } -TrigPtr TrigBufCache::smallTrigCombo(u32 width, u32 middle, u32 W, u32 nW, u32 variant, bool tail_single_wide, u32 tail_trigs) { - if (tail_trigs == 2) // No pre-computed trig values. We might be able to share this trig table with fft_WIDTH - return smallTrig(W, nW); +TrigPtr TrigBufCache::smallTrigCombo(u32 width, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs) { + if (tail_trigs == 2 && !NTT_GF31 && !NTT_GF61) // No pre-computed trig values. We might be able to share this trig table with fft_WIDTH + return smallTrig(height, nH, middle, height, nH, variant, tail_single_wide, tail_trigs); lock_guard lock{mut}; auto& m = small; - decay_t::key_type key1{W, nW, width, middle, tail_single_wide, tail_trigs}; - // We write the "combo" under two keys, so it can also be retrieved as non-combo by smallTrig() - decay_t::key_type key2{W, nW, 0, 0, 0, 0}; + decay_t::key_type key{height, nH, width, middle, tail_single_wide, tail_trigs}; TrigPtr p{}; - auto it = m.find(key1); + auto it = m.find(key); if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genSmallTrigCombo(width, middle, W, nW, tail_single_wide, tail_trigs)); - m[key1] = p; - m[key2] = p; + p = make_shared(context, genSmallTrigCombo(width, middle, height, nH, tail_single_wide, tail_trigs)); + m[key] = p; smallCache.add(p); } return p; diff --git a/src/TrigBufCache.h b/src/TrigBufCache.h index 04ace546..3270f558 100644 --- a/src/TrigBufCache.h +++ b/src/TrigBufCache.h @@ -6,7 +6,6 @@ #include -using double2 = pair; using TrigBuf = Buffer; using TrigPtr = shared_ptr; @@ -42,12 +41,63 @@ class TrigBufCache { ~TrigBufCache(); - TrigPtr smallTrig(u32 W, u32 nW); - TrigPtr smallTrigCombo(u32 width, u32 middle, u32 W, u32 nW, u32 variant, bool tail_single_wide, u32 tail_trigs); + TrigPtr smallTrigCombo(u32 width, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs); TrigPtr middleTrig(u32 SMALL_H, u32 MIDDLE, u32 W); + TrigPtr smallTrig(u32 width, u32 nW, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs); }; -// For small angles, return "fancy" cos - 1 for increased precision -double2 root1Fancy(u32 N, u32 k); +#if FFT_FP64 +double2 root1Fancy(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision double2 root1(u32 N, u32 k); +#endif + +#if FFT_FP32 +float2 root1Fancy(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision +float2 root1(u32 N, u32 k); +#endif + +#if NTT_GF31 +uint2 root1GF31(u32 N, u32 k); +#endif + +#if NTT_GF61 +ulong2 root1GF61(u32 N, u32 k); +#endif + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of float2 values) +#define SMALLTRIG_FP64_SIZE(W,M,H,nH) (W != H || H == 0 ? W * 5 : SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH)) // See genSmallTrigFP64 +#define SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH) (H * 5 + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboFP64 +#define MIDDLETRIG_FP64_SIZE(W,M,H) (H + W + H) // See genMiddleTrigFP64 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of float2 values) +#define SMALLTRIG_FP32_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH)) // See genSmallTrigFP32 +#define SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboFP32 +#define MIDDLETRIG_FP32_SIZE(W,M,H) (H + W + H) // See genMiddleTrigFP32 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of uint2 values) +#define SMALLTRIG_GF31_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH)) // See genSmallTrigGF31 +#define SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF31 +#define MIDDLETRIG_GF31_SIZE(W,M,H) (H + W + H) // See genMiddleTrigGF31 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of ulong2 values) +#define SMALLTRIG_GF61_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH)) // See genSmallTrigGF61 +#define SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF61 +#define MIDDLETRIG_GF61_SIZE(W,M,H) (H + W + H) // See genMiddleTrigGF61 + +// Convert above sizes to distances (in units of double2) +#define SMALLTRIG_FP64_DIST(W,M,H,nH) SMALLTRIG_FP64_SIZE(W,M,H,nH) +#define SMALLTRIGCOMBO_FP64_DIST(W,M,H,nH) SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH) +#define MIDDLETRIG_FP64_DIST(W,M,H) MIDDLETRIG_FP64_SIZE(W,M,H) + +#define SMALLTRIG_FP32_DIST(W,M,H,nH) SMALLTRIG_FP32_SIZE(W,M,H,nH) * sizeof(float) / sizeof(double) +#define SMALLTRIGCOMBO_FP32_DIST(W,M,H,nH) SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH) * sizeof(float) / sizeof(double) +#define MIDDLETRIG_FP32_DIST(W,M,H) MIDDLETRIG_FP32_SIZE(W,M,H) * sizeof(float) / sizeof(double) + +#define SMALLTRIG_GF31_DIST(W,M,H,nH) SMALLTRIG_GF31_SIZE(W,M,H,nH) * sizeof(uint) / sizeof(double) +#define SMALLTRIGCOMBO_GF31_DIST(W,M,H,nH) SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH) * sizeof(uint) / sizeof(double) +#define MIDDLETRIG_GF31_DIST(W,M,H) MIDDLETRIG_GF31_SIZE(W,M,H) * sizeof(uint) / sizeof(double) + +#define SMALLTRIG_GF61_DIST(W,M,H,nH) SMALLTRIG_GF61_SIZE(W,M,H,nH) * sizeof(ulong) / sizeof(double) +#define SMALLTRIGCOMBO_GF61_DIST(W,M,H,nH) SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH) * sizeof(ulong) / sizeof(double) +#define MIDDLETRIG_GF61_DIST(W,M,H) MIDDLETRIG_GF61_SIZE(W,M,H) * sizeof(ulong) / sizeof(double) diff --git a/src/cl/base.cl b/src/cl/base.cl index 8190406c..b92f8fb1 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -63,6 +63,11 @@ G_H "group height" == SMALL_HEIGHT / NH #endif #endif // AMDGPU +// Default is not adding -2 to results for LL +#if !defined(LL) +#define LL 0 +#endif + // On Nvidia we need the old sync between groups in carryFused #if !defined(OLD_FENCE) && !AMDGPU #define OLD_FENCE 1 @@ -144,13 +149,65 @@ typedef int i32; typedef uint u32; typedef long i64; typedef ulong u64; +typedef __int128 i128; +typedef unsigned __int128 u128; + +// Typedefs and defines for supporting hybrid FFTs +#if !defined(FFT_FP64) +#define FFT_FP64 1 +#endif +#if !defined(FFT_FP32) +#define FFT_FP32 0 +#endif +#if !defined(NTT_GF31) +#define NTT_GF31 0 +#endif +#if !defined(NTT_GF61) +#define NTT_GF61 0 +#endif +#if !defined(NTT_NCW) +#define NTT_NCW 0 +#endif +#if NTT_NCW +#error Nick Craig-Woods NTT prime is not supported now +#endif +// Data types for data stored in FFTs and NTTs during the transform +typedef double T; // For historical reasons, classic FFTs using doubles call their data T and T2. +typedef double2 T2; // A complex value using doubles in a classic FFT. +typedef float F; // A classic FFT using floats. Use typedefs F and F2. +typedef float2 F2; +typedef uint Z31; // A value calculated mod M31. For a GF(M31^2) NTT. +typedef uint2 GF31; // A complex value using two Z31s. For a GF(M31^2) NTT. +typedef ulong Z61; // A value calculated mod M61. For a GF(M61^2) NTT. +typedef ulong2 GF61; // A complex value using two Z61s. For a GF(M61^2) NTT. +//typedef ulong NCW; // A value calculated mod 2^64 - 2^32 + 1. +//typedef ulong2 NCW2; // A complex value using NCWs. For a Nick Craig-Wood's insipred NTT using prime 2^64 - 2^32 + 1. + +// Typedefs for "combo" FFT/NTTs (multiple NTT primes or hybrid FFT/NTT). +// Word and Word2 define the data type for FFT integers passed between the CPU and GPU. +#define COMBO_FFT (FFT_FP64 + FFT_FP32 + NTT_GF31 + NTT_GF61 + NTT_NCW > 1) +#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61 & !NTT_NCW) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32 & !NTT_NCW) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31 & !NTT_NCW) +#define WordSize 8 +typedef i64 Word; +typedef long2 Word2; +#elif !COMBO_FFT | (FFT_FP32 & NTT_GF31 & !FFT_FP64 & !NTT_GF61 & !NTT_NCW) +#define WordSize 4 typedef i32 Word; typedef int2 Word2; +#else +error - unsupported FFT/NTT combination +#endif -typedef double T; -typedef double2 T2; +// Routine to create a pair +double2 OVERLOAD U2(double a, double b) { return (double2) (a, b); } +float2 OVERLOAD U2(float a, float b) { return (float2) (a, b); } +int2 OVERLOAD U2(int a, int b) { return (int2) (a, b); } +long2 OVERLOAD U2(long a, long b) { return (long2) (a, b); } +uint2 OVERLOAD U2(uint a, uint b) { return (uint2) (a, b); } +ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } +// Other handy macros #define RE(a) (a.x) #define IM(a) (a.y) @@ -170,34 +227,79 @@ typedef double2 T2; #if AMDGPU typedef constant const T2* Trig; typedef constant const T* TrigSingle; +typedef constant const F2* TrigFP32; +typedef constant const GF31* TrigGF31; +typedef constant const GF61* TrigGF61; #else typedef global const T2* Trig; typedef global const T* TrigSingle; +typedef global const F2* TrigFP32; +typedef global const GF31* TrigGF31; +typedef global const GF61* TrigGF61; #endif // However, caching weights in nVidia's constant cache improves performance. // Even better is to not pollute the constant cache with weights that are used only once. // This requires two typedefs depending on how we want to use the BigTab pointer. // For AMD we can declare BigTab as constant or global - it doesn't really matter. typedef constant const double2* ConstBigTab; +typedef constant const float2* ConstBigTabFP32; #if AMDGPU typedef constant const double2* BigTab; +typedef constant const float2* BigTabFP32; #else typedef global const double2* BigTab; +typedef global const float2* BigTabFP32; #endif #define KERNEL(x) kernel __attribute__((reqd_work_group_size(x, 1, 1))) void - -void read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base) { + +#if FFT_FP64 +void OVERLOAD read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base) { in += base + (u32) get_local_id(0); for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } } -void write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base) { +void OVERLOAD write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base) { out += base + (u32) get_local_id(0); for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } } +#endif + +#if FFT_FP32 +void OVERLOAD read(u32 WG, u32 N, F2 *u, const global F2 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} -T2 U2(T a, T b) { return (T2) (a, b); } +void OVERLOAD write(u32 WG, u32 N, F2 *u, global F2 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif + +#if NTT_GF31 +void OVERLOAD read(u32 WG, u32 N, GF31 *u, const global GF31 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} + +void OVERLOAD write(u32 WG, u32 N, GF31 *u, global GF31 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif + +#if NTT_GF61 +void OVERLOAD read(u32 WG, u32 N, GF61 *u, const global GF61 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} + +void OVERLOAD write(u32 WG, u32 N, GF61 *u, global GF61 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif // On "classic" AMD GCN GPUs such as Radeon VII, the wavefront size was always 64. On RDNA GPUs the wavefront can // be configured to be either 64 or 32. We use the FAST_BARRIER define as an indicator for GCN GPUs. diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 0a613c73..35c34ee0 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -3,33 +3,82 @@ #include "carryutil.cl" #include "weight.cl" +#if FFT_FP64 & !COMBO_FFT + // Carry propagation with optional MUL-3, over CARRY_LEN words. -// Input arrives conjugated and inverse-weighted. +// Input arrives with real and imaginary values swapped and weighted. -KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, CP(u32) bits, - BigTab THREAD_WEIGHTS, P(uint) bufROE) { +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTab THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); u32 me = get_local_id(0); u32 gx = g % NW; u32 gy = g / NW; + u32 H = BIG_HEIGHT; // & vs. && to workaround spurious warning CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; float roundMax = 0; float carryMax = 0; - // Split 32 bits into CARRY_LEN groups of 2 bits. -#define GPW (16 / CARRY_LEN) - u32 b = bits[(G_W * g + me) / GPW] >> (me % GPW * (2 * CARRY_LEN)); -#undef GPW + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. + u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); T base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); for (i32 i = 0; i < CARRY_LEN; ++i) { u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; - double w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); - double w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); - out[p] = weightAndCarryPair(conjugate(in[p]), U2(w1, w2), carry, &roundMax, &carry, test(b, 2 * i), test(b, 2 * i + 1), &carryMax); + T w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + T w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + out[p] = weightAndCarryPair(SWAP_XY(in[p]), U2(w1, w2), carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_FP32 & !COMBO_FFT + +// Carry propagation with optional MUL-3, over CARRY_LEN words. +// Input arrives with real and imaginary values swapped and weighted. + +KERNEL(G_W) carry(P(Word2) out, CP(F2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. + u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + out[p] = weightAndCarryPair(SWAP_XY(in[p]), U2(w1, w2), carry, biglit0, biglit1, &carry, &roundMax, &carryMax); } carryOut[G_W * g + me] = carry; @@ -39,3 +88,484 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, CP( updateStats(bufROE, posROE, carryMax); #endif } + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & !COMBO_FFT + +KERNEL(G_W) carry(P(Word2) out, CP(GF31) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 30th root GF31. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), weight_shift0, weight_shift1, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) M31; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF61 & !COMBO_FFT + +KERNEL(G_W) carry(P(Word2) out, CP(GF61) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), weight_shift0, weight_shift1, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) (M61 >> 32); // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP64 & NTT_GF31 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + T base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP64 and second GF31 weight shift + T w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + T w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), SWAP_XY(in31[p]), w1, w2, weight_shift0, weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF31 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in31[p]), w1, w2, weight_shift0, weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF61 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF61 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in61[p]), w1, w2, weight_shift0, weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & NTT_GF61 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = (m31_weight_shift + log2_NWORDS + 1) % 31; + m61_weight_shift = (m61_weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in31[p]), SWAP_XY(in61[p]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; +// GWBUG - derive m61 weight shifts from m31 counter (or vice versa) sort of easily done from difference in the two weight shifts (no need to add frac_bits twice) + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) 0x0FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +#else +error - missing Carry kernel implementation +#endif diff --git a/src/cl/carryb.cl b/src/cl/carryb.cl index 1f424c9f..d6033e19 100644 --- a/src/cl/carryb.cl +++ b/src/cl/carryb.cl @@ -2,19 +2,20 @@ #include "carryutil.cl" -KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn, CP(u32) bits) { +KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn) { u32 g = get_group_id(0); - u32 me = get_local_id(0); + u32 me = get_local_id(0); u32 gx = g % NW; u32 gy = g / NW; + u32 H = BIG_HEIGHT; - // Split 32 bits into CARRY_LEN groups of 2 bits. -#define GPW (16 / CARRY_LEN) - u32 b = bits[(G_W * g + me) / GPW] >> (me % GPW * (2 * CARRY_LEN)); -#undef GPW + // Derive the big vs. little flags from the fractional number of bits in each FFT word rather read the flags from memory. + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. + u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); - u32 step = G_W * gx + WIDTH * CARRY_LEN * gy; - io += step; + io += G_W * gx + WIDTH * CARRY_LEN * gy; u32 HB = BIG_HEIGHT / CARRY_LEN; @@ -26,7 +27,9 @@ KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn, CP(u32) bits) { for (i32 i = 0; i < CARRY_LEN; ++i) { u32 p = i * WIDTH + me; - io[p] = carryWord(io[p], &carry, test(b, 2 * i), test(b, 2 * i + 1)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + io[p] = carryWord(io[p], &carry, biglit0, biglit1); if (!carry) { return; } } } diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index 46bba3c3..a571aef7 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -16,6 +16,8 @@ void spin() { #endif } +#if FFT_FP64 & !COMBO_FFT + // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, @@ -26,14 +28,15 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( #else local T2 lds[WIDTH / 2]; #endif + + T2 u[NW]; + u32 gr = get_group_id(0); u32 me = get_local_id(0); u32 H = BIG_HEIGHT; u32 line = gr % H; - T2 u[NW]; - #if HAS_ASM __asm("s_setprio 3"); #endif @@ -50,9 +53,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding // common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs // which causes a terrible reduction in occupancy. -// fft_WIDTH(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); #if ZEROHACK_W - new_fft_WIDTH1(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); #else new_fft_WIDTH1(lds, u, smallTrig); #endif @@ -82,12 +85,12 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( float roundMax = 0; float carryMax = 0; - // On Titan V it is faster to derive the big vs. little flags from the fractional number of bits in each FFT word rather read the flags from memory. + // On Titan V it is faster to derive the big vs. little flags from the fractional number of bits in each FFT word rather than read the flags from memory. // On Radeon VII this code is about the same speed. Not sure which is better on other GPUs. #if BIGLIT - // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. - u32 fft_word_index = (me * H + line) * 2; - u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); + // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. + u32 word_index = (me * H + line) * 2; + u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); #endif // Apply the inverse weights and carry propagate pairs to generate the output carries @@ -110,10 +113,10 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. - wu[i] = weightAndCarryPairSloppy(conjugate(u[i]), U2(invWeight1, invWeight2), + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), U2(invWeight1, invWeight2), // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, &roundMax, &carry[i], biglit0, biglit1, &carryMax); + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); } #if ROE @@ -212,7 +215,197 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( bool biglit0 = test(b, 2 * i); #endif wu[i] = carryFinal(wu[i], carry[i], biglit0); - u[i] *= U2(wu[i].x, wu[i].y); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); + } + + bar(); + +// fft_WIDTH(lds, u, smallTrig); + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_FP32 & !COMBO_FFT + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigFP32 smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local F2 lds[WIDTH / 4]; +#else + local F2 lds[WIDTH / 2]; +#endif + + F2 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. + u32 word_index = (me * H + line) * 2; + u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + + for (u32 i = 0; i < NW; ++i) { + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + + // Generate big-word/little-word flags + bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; + bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; // Same as frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), U2(invWeight1, invWeight2), + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + } + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + u[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words + for (i32 i = 0; i < NW; ++i) { + bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); } bar(); @@ -222,3 +415,1444 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( writeCarryFusedLine(u, out, line); } + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & !COMBO_FFT + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(GF31) out, CP(GF31) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigGF31 smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF31 lds[WIDTH / 4]; +#else + local GF31 lds[WIDTH / 2]; +#endif + + GF31 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; + +#if MUL3 + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; +#else + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; +#endif + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF31. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + float fltRoundMax = (float) roundMax / (float) M31; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF61 & !COMBO_FFT + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(GF61) out, CP(GF61) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigGF61 smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds[WIDTH / 4]; +#else + local GF61 lds[WIDTH / 2]; +#endif + + GF61 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; + +#if MUL3 + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; +#else + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; +#endif + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 61) weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + float fltRoundMax = (float) roundMax / (float) (M61 >> 32); // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(shl(make_Z61(wu[i].x), weight_shift0), shl(make_Z61(wu[i].y), weight_shift1)); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP64 & NTT_GF31 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + + local T2 lds[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds; + + T2 u[NW]; + GF31 u31[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + readCarryFusedLine(in31, u31, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); +#endif + + Word2 wu[NW]; +#if AMDGPU + T2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + T2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + T invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP64 weights and second GF31 weight shift + T invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + T invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), SWAP_XY(u31[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + T base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + T weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + T weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + u[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + writeCarryFusedLine(u, out, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + + local F2 ldsF2[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) ldsF2; + + F2 uF2[NW]; + GF31 u31[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in31, u31, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i32) carryShuttlePtr = (P(i32)) carryShuttle; + i32 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF31 weight shift + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u31[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF61 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + + F2 uF2[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 61) weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF61 weight shift + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u61[i] = U2(shl(make_Z61(wu[i].x), weight_shift0), shl(make_Z61(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & NTT_GF61 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds61[WIDTH / 4]; +#else + local GF61 lds61[WIDTH / 2]; +#endif + local GF31 *lds31 = (local GF31 *) lds61; + + GF31 u31[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in31, u31, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(lds31, u31, smallTrig31); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * H * 2 - 1) * m31_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + u64 m31_starting_combo_counter = m31_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * H * 2 - 1) * m61_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + u64 m61_starting_combo_counter = m61_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = m31_weight_shift + log2_NWORDS + 1; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_weight_shift = m61_weight_shift + log2_NWORDS + 1; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u31[i]), SWAP_XY(u61[i]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation + m61_combo_counter = m61_starting_combo_counter; + +#if ROE + float fltRoundMax = (float) roundMax / (float) 0x0FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float - divide by M61*M31. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u31[i] = U2(shl(make_Z31(wu[i].x), m31_weight_shift0), shl(make_Z31(wu[i].y), m31_weight_shift1)); + u61[i] = U2(shl(make_Z61(wu[i].x), m61_weight_shift0), shl(make_Z61(wu[i].y), m61_weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + +#else +error - missing CarryFused kernel implementation +#endif diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 1ee8e440..0b7241f5 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -2,16 +2,6 @@ // This file is included with different definitions for iCARRY -Word2 OVERLOAD carryPair(long2 u, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { - iCARRY midCarry; - Word a = carryStep(u.x, &midCarry, b1); - Word b = carryStep(u.y + midCarry, outCarry, b2); -// #if STATS & 0x5 - *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); -// #endif - return (Word2) (a, b); -} - Word2 OVERLOAD carryFinal(Word2 u, iCARRY inCarry, bool b1) { i32 tmpCarry; u.x = carryStep(u.x + inCarry, &tmpCarry, b1); @@ -19,10 +9,11 @@ Word2 OVERLOAD carryFinal(Word2 u, iCARRY inCarry, bool b1) { return u; } +#if FFT_FP64 & !COMBO_FFT // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. -Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, float* maxROE, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { +Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); Word a = carryStep(tmp1, &midCarry, b1); @@ -32,8 +23,8 @@ Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, float* maxROE return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accuracy calculation of the first carry is not required. -Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, float* maxROE, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); Word a = carryStepSloppy(tmp1, &midCarry, b1); @@ -42,3 +33,225 @@ Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, float* *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_FP32 & !COMBO_FFT + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 u, F2 invWeight, iCARRY inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i32 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); + Word a = carryStep(tmp1, &midCarry, b1); + i32 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 u, F2 invWeight, iCARRY inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i32 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i32 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & !COMBO_FFT + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF61 & !COMBO_FFT + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF61 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(GF61 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP64 & NTT_GF31 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i32 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i64 tmp1 = weightAndCarryOne(uF2.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(uF2.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i32 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i64 tmp1 = weightAndCarryOne(uF2.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(uF2.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF61 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & NTT_GF61 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStepSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +#else +error - missing carryinc implementation +#endif diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index ed2545f7..0a790db3 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -3,18 +3,16 @@ #include "base.cl" #include "math.cl" -#if STATS || ROE -void updateStats(global uint *bufROE, u32 posROE, float roundMax) { - assert(roundMax >= 0); - // work_group_reduce_max() allocates an additional 256Bytes LDS for a 64lane workgroup, so avoid it. - // u32 groupRound = work_group_reduce_max(as_uint(roundMax)); - // if (get_local_id(0) == 0) { atomic_max(bufROE + posROE, groupRound); } - - // Do the reduction directly over global mem. - atomic_max(bufROE + posROE, as_uint(roundMax)); -} +#if CARRY64 +typedef i64 CFcarry; +#else +typedef i32 CFcarry; #endif +// The carry for the non-fused CarryA, CarryB, CarryM kernels. +// Simply use large carry always as the split kernels are slow anyway (and seldomly used normally). +typedef i64 CarryABM; + #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) i32 lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } #else @@ -22,9 +20,9 @@ i32 lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } #endif #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) -i32 ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } +u32 ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } #else -i32 ulowBits(i32 u, u32 bits) { u32 uu = (u32) u; return ((uu << (32 - bits)) >> (32 - bits)); } +u32 ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } #endif #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) @@ -33,10 +31,6 @@ i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, a i32 xtract32(i64 x, u32 bits) { return x >> bits; } #endif -#if !defined(LL) -#define LL 0 -#endif - u32 bitlen(bool b) { return EXP / NWORDS + b; } bool test(u32 bits, u32 pos) { return (bits >> pos) & 1; } @@ -53,6 +47,170 @@ void ROUNDOFF_CHECK(double x) { } #endif +Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + +//GWBUG - is this ever faster? +//i128 x128 = ((i128) (i64) i96_hi64(x) << 32) | i96_lo32(x); +//i64 w = ((i64) x128 << (64 - nBits)) >> (64 - nBits); +//x128 -= w; +//*outCarry = x128 >> nBits; +//return w; + +// This code is tricky because me must not shift i32 or u32 variables by 32. +#if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? + i64 xhi = i96_hi64(x); + i64 w = lowBits(xhi, nBits - 32); + xhi -= w; + *outCarry = xhi >> (nBits - 32); + return (w << 32) | i96_lo32(x); +#elif EXP / NWORDS == 32 + i64 xhi = i96_hi64(x); + i64 w = ((i64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); +// xhi -= w >> 32; +// *outCarry = xhi >> (nBits - 32); //GWBUG - Would adding (w < 0) be faster than subtracting w>>32 from xhi? + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return w; +#elif EXP / NWORDS == 31 + i64 w = ((i64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); + *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); + return w; +#else + i32 w = lowBits(i96_lo32(x), nBits); + *outCarry = ((i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits)) + (w < 0); + return w; +#endif +} + +Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 33 + i32 xhi = (x >> 32); + i32 w = lowBits(xhi, nBits - 32); + xhi -= w; + *outCarry = xhi >> (nBits - 32); + return (Word) (((u64) w << 32) | (u32)(x)); +#elif EXP / NWORDS == 32 + i32 xhi = (x >> 32); + i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); + xhi -= w >> 32; + *outCarry = xhi >> (nBits - 32); + return w; +#elif EXP / NWORDS == 31 + i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); + x -= w; + *outCarry = x >> nBits; + return w; +#else + Word w = lowBits((i32) x, nBits); + x -= w; + *outCarry = x >> nBits; + return w; +#endif +} + +Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 33 + i32 xhi = (x >> 32); + i32 w = lowBits(xhi, nBits - 32); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return (Word) (((u64) w << 32) | (u32)(x)); +#elif EXP / NWORDS == 32 + i32 xhi = (x >> 32); + i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); + *outCarry = (i32) (xhi >> (nBits - 32)) + (w < 0); + return w; +#elif EXP / NWORDS == 31 + i32 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); + *outCarry = (i32) (x >> nBits) + (w < 0); + return w; +#else + Word w = lowBits(x, nBits); + *outCarry = xtract32(x, nBits) + (w < 0); + return w; +#endif +} + +Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + Word w = lowBits(x, nBits); + *outCarry = (x - w) >> nBits; + return w; +} + +Word OVERLOAD carryStepSloppy(i96 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + +// GWBUG Is this faster (or same speed) ???? This code doesn't work on TitanV??? +//i128 x128 = ((i128) xhi << 32) | i96_lo32(x); +//*outCarry = x128 >> nBits; +//return ((u64) x128 << (64 - nBits)) >> (64 - nBits); + +// This code is tricky because me must not shift i32 or u32 variables by 32. +#if EXP / NWORDS >= 33 // nBits is 33 or more + i64 xhi = i96_hi64(x); + *outCarry = xhi >> (nBits - 32); + return (Word) (((u64) ulowBits((i32) xhi, nBits - 32) << 32) | i96_lo32(x)); +#elif EXP / NWORDS == 32 // nBits = 32 or 33 + i64 xhi = i96_hi64(x); + *outCarry = xhi >> (nBits - 32); + u64 xlo = i96_lo64(x); + return (xlo << (64 - nBits)) >> (64 - nBits); // ulowBits(xlo, nBits); +#elif EXP / NWORDS == 31 // nBits = 31 or 32 + *outCarry = (i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16)); + return ((u64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); // ulowBits(xlo, nBits); +#else // nBits less than 32 + *outCarry = (i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits); + return ulowBits(i96_lo32(x), nBits); +#endif +} + +Word OVERLOAD carryStepSloppy(i64 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +Word OVERLOAD carryStepSloppy(i64 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = xtract32(x, nBits); + return ulowBits(x, nBits); +} + +Word OVERLOAD carryStepSloppy(i32 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +// Carry propagation from word and carry. +Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { + a.x = carryStep(a.x + *carry, carry, b1); + a.y = carryStep(a.y + *carry, carry, b2); + return a; +} + +// map abs(carry) to floats, with 2^32 corresponding to 1.0 +// So that the maximum CARRY32 abs(carry), 2^31, is mapped to 0.5 (the same as the maximum ROE) +float OVERLOAD boundCarry(i32 c) { return ldexp(fabs((float) c), -32); } +float OVERLOAD boundCarry(i64 c) { return ldexp(fabs((float) (i32) (c >> 8)), -24); } + +#if STATS || ROE +void updateStats(global uint *bufROE, u32 posROE, float roundMax) { + assert(roundMax >= 0); + // work_group_reduce_max() allocates an additional 256Bytes LDS for a 64lane workgroup, so avoid it. + // u32 groupRound = work_group_reduce_max(as_uint(roundMax)); + // if (get_local_id(0) == 0) { atomic_max(bufROE + posROE, groupRound); } + + // Do the reduction directly over global mem. + atomic_max(bufROE + posROE, as_uint(roundMax)); +} +#endif + + +#if FFT_FP64 + // Rounding constant: 3 * 2^51, See https://stackoverflow.com/questions/17035464 #define RNDVAL (3.0 * (1l << 51)) @@ -70,6 +228,10 @@ i64 RNDVALdoubleToLong(double d) { return as_long(words); } +#endif + +#if FFT_FP64 & !COMBO_FFT + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -105,76 +267,336 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r #endif } -Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); - x -= w; - *outCarry = x >> nBits; + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_FP32 & !COMBO_FFT + +// Rounding constant: 3 * 2^22 +#define RNDVAL (3.0f * (1 << 22)) + +// Convert a float to int efficiently. Float must be in RNDVAL+integer format. +i32 RNDVALfloatToInt(float d) { + int w = as_int(d); +//#if 0 +// We extend the range to 23 bits instead of 22 by taking the sign from the negation of bit 22 +// w ^= 0x00800000u; +// w = lowBits(words.y, 23); +//#else +// // Take the sign from bit 21 (i.e. use lower 22 bits). + w = lowBits(w, 22); +//#endif return w; } -Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); - *outCarry = xtract32(x, nBits) + (w < 0); - return w; +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_result_is_acceptable) { + +#if !MUL3 + + // Convert carry into RNDVAL + carry. + float RNDVALCarry = as_float(as_int(RNDVAL) + inCarry); // GWBUG - just the float arithmetic? s.b. fast + + // Apply inverse weight and RNDVAL+carry + float d = fma(u, invWeight, RNDVALCarry); + + // Optionally calculate roundoff error + float roundoff = fabs(fma(u, -invWeight, d - RNDVALCarry)); + *maxROE = max(*maxROE, roundoff); + + // Convert to int + return RNDVALfloatToInt(d); + +#else // We cannot add in the carry until after the mul by 3 + + // Apply inverse weight and RNDVAL + float d = fma(u, invWeight, RNDVAL); + + // Optionally calculate roundoff error + float roundoff = fabs(fma(u, -invWeight, d - RNDVAL)); + *maxROE = max(*maxROE, roundoff); + + // Convert to int, mul by 3, and add carry + return RNDVALfloatToInt(d) * 3 + inCarry; + +#endif } -Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); - *outCarry = (x - w) >> nBits; - return w; + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & !COMBO_FFT + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { + + // Apply inverse weight + u = shr(u, invWeight); + + // Convert input to balanced representation + i32 value = get_balanced_Z31(u); + + // Optionally calculate roundoff error as proximity to M31/2. + u32 roundoff = (u32) abs(value); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + value *= 3; +#endif + return value + inCarry; } -Word OVERLOAD carryStepSloppy(i64 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); - *outCarry = x >> nBits; - return w; + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF61 & !COMBO_FFT + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { + + // Apply inverse weight + u = shr(u, invWeight); + + // Convert input to balanced representation + i64 value = get_balanced_Z61(u); + + // Optionally calculate roundoff error as proximity to M61/2. 28 bits of accuracy should be sufficient. + u32 roundoff = (u32) abs((i32) (value >> 32)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + value *= 3; +#endif + return value + inCarry; } -Word OVERLOAD carryStepSloppy(i64 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); - *outCarry = xtract32(x, nBits); - return w; + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP64 & NTT_GF31 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z31 data + u31 = shr(u31, m31_invWeight); + u32 n31 = get_Z31(u31); + + // The final result must be n31 mod M31. Use FP64 data to calculate this value. + u = u * invWeight - (double) n31; // This should be close to a multiple of M31 + u *= 4.656612875245796924105750827168e-10; // Divide by M31. Could divide by 2^31 (0.0000000004656612873077392578125) be accurate enough? //GWBUG - check the generated code! Use 1/M31??? + + i64 n64 = RNDVALdoubleToLong(u + RNDVAL); + + i128 v = ((i128) n64 << 31) - n64; // n64 * M31 + v += n31; + + // Optionally calculate roundoff error + float roundoff = (float) fabs(u - (double) n64); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + v = v * 3; +#endif + v += inCarry; + i96 value = make_i96((u64) (v >> 32), (u32) v); + return value; } -Word OVERLOAD carryStepSloppy(i32 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); - *outCarry = x >> nBits; - return w; + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, i32 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z31 data + u31 = shr(u31, m31_invWeight); + u32 n31 = get_Z31(u31); + + // The final result must be n31 mod M31. Use FP32 data to calculate this value. + uF2 = uF2 * F2_invWeight - (float) n31; // This should be close to a multiple of M31 + uF2 *= 0.0000000004656612873077392578125f; // Divide by 2^31 //GWBUG - check the generated code! + +// i32 nF2 = rint(uF2); // GWBUG - does this round cheaply? Best way to round? +// Rounding constant: 3 * 2^22 +#define RNDVAL (3.0f * (1 << 22)) + i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); + + i64 v = ((i64) nF2 << 31) - nF2; // nF2 * M31 + v += n31; + + // Optionally calculate roundoff error + float roundoff = fabs(uF2 - nF2); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + v = v * 3; +#endif + return v + inCarry; } -// map abs(carry) to floats, with 2^32 corresponding to 1.0 -// So that the maximum CARRY32 abs(carry), 2^31, is mapped to 0.5 (the same as the maximum ROE) -float OVERLOAD boundCarry(i32 c) { return ldexp(fabs((float) c), -32); } -float OVERLOAD boundCarry(i64 c) { return ldexp(fabs((float) (i32) (c >> 8)), -24); } -#define iCARRY i32 -#include "carryinc.cl" -#undef iCARRY +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ -#define iCARRY i64 -#include "carryinc.cl" -#undef iCARRY +#elif FFT_FP32 & NTT_GF61 -#if CARRY64 -typedef i64 CFcarry; +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z61 data + u61 = shr(u61, m61_invWeight); + u64 n61 = get_Z61(u61); + +#if 0 +BUG - need more than 64 bit integers + + // The final result must be n61 mod M61. Use FP32 data to calculate this value. + uF2 = uF2 * F2_invWeight - (float) n61; // This should be close to a multiple of M61 + uF2 *= 4.3368086899420177360298112034798e-19f; // Divide by 2^61 //GWBUG - check the generated code! + +// i32 nF2 = rint(uF2); // GWBUG - does this round cheaply? Best way to round? +// Rounding constant: 3 * 2^22 +#define RNDVAL (3.0f * (1 << 22)) + i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); + + i64 v = ((i64) nF2 << 61) - nF2; // nF2 * M61 + v += n61; + + // Optionally calculate roundoff error + float roundoff = fabs(uF2 - (float) nF2); + *maxROE = max(*maxROE, roundoff); #else -typedef i32 CFcarry; + + // The final result must be n61 mod M61. Use FP32 data to calculate this value. +#define RNDVAL (3.0 * (1l << 51)) + double uuF2 = (double) uF2 * (double) F2_invWeight - (double) n61; // This should be close to a multiple of M61 + uuF2 = uuF2 * 4.3368086899420177360298112034798e-19; // Divide by 2^61 //GWBUG - check the generated code! +volatile double xxF2 = uuF2 + RNDVAL; // Divide by 2^61 //GWBUG - check the generated code! + xxF2 -= RNDVAL; + i32 nF2 = (int) xxF2; + + i128 v = ((i128) nF2 << 61) - nF2; // nF2 * M61 + v += n61; + + // Optionally calculate roundoff error + float roundoff = (float) fabs(uuF2 - (double) nF2); + *maxROE = max(*maxROE, roundoff); #endif -// The carry for the non-fused CarryA, CarryB, CarryM kernels. -// Simply use large carry always as the split kernels are slow anyway (and seldomly used normally). -typedef i64 CarryABM; + // Mul by 3 and add carry +#if MUL3 + v = v * 3; +#endif + v += inCarry; + i96 value = make_i96((u64) (v >> 32), (u32) v); + return value; +} -// Carry propagation from word and carry. -Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { - a.x = carryStep(a.x + *carry, carry, b1); - a.y = carryStep(a.y + *carry, carry, b2); - return a; + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & NTT_GF61 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, u32* maxROE) { + + // Apply inverse weights + u31 = shr(u31, m31_invWeight); + u61 = shr(u61, m61_invWeight); + + // Use chinese remainder theorem to create a 92-bit result. Loosely copied from Yves Gallot's mersenne2 program. + u32 n31 = get_Z31(u31); + u61 = sub(u61, make_Z61(n31)); // u61 - u31 + u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) + u64 n61 = get_Z61(u61); + +#if INT128_MATH +i128 v = ((i128) n61 << 31) + n31 - n61; //GWBUG - is this better/as good as int96 code? +// +// i96 value = make_i96(n61 >> 1, ((u32) n61 << 31) | n31); // (n61<<31) + n31 +// i96_sub(&value, n61); + + // Convert to balanced representation by subtracting M61*M31 +if ((v >> 64) & 0xF8000000) v = v - (i128) M31 * (i128) M61; +// if (i96_hi32(value) & 0xF8000000) i96_sub(&value, make_i96(0x0FFFFFFF, 0xDFFFFFFF, 0x80000001)); + + // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. +// u32 roundoff = (u32) abs((i32) i96_hi32(value)); +u32 roundoff = (u32) abs((i32)(v >> 64)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 +v = v * 3; +// i96_mul(&value, 3); +#endif +// i96_add(&value, make_i96((u32)(inCarry >> 63), (u64) inCarry)); +v = v + inCarry; +i96 value = make_i96((u64) (v >> 32), (u32) v); + +#else + + i96 value = make_i96(n61 >> 1, ((u32) n61 << 31) | n31); // (n61<<31) + n31 + i96_sub(&value, n61); + + // Convert to balanced representation by subtracting M61*M31 + if (i96_hi32(value) & 0xF8000000) i96_sub(&value, make_i96(0x0FFFFFFF, 0xDFFFFFFF, 0x80000001)); + + // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. + u32 roundoff = (u32) abs((i32) i96_hi32(value)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + i96_mul(&value, 3); +#endif + i96_add(&value, make_i96((u32)(inCarry >> 63), (u64) inCarry)); + +#endif + + return value; } + +#else +error - missing carryUtil implementation +#endif + + + +/**************************************************************************/ +/* Do this last, it depends on weightAndCarryOne defined above */ +/**************************************************************************/ + +/* Support both 32-bit and 64-bit carries */ + +#if WordSize <= 4 +#define iCARRY i32 +#include "carryinc.cl" +#undef iCARRY +#endif + +#define iCARRY i64 +#include "carryinc.cl" +#undef iCARRY + diff --git a/src/cl/etc.cl b/src/cl/etc.cl index 89fcaf7b..ae4fd857 100644 --- a/src/cl/etc.cl +++ b/src/cl/etc.cl @@ -18,7 +18,7 @@ KERNEL(32) readResidue(P(Word2) out, CP(Word2) in) { #if SUM64 KERNEL(64) sum64(global ulong* out, u32 sizeBytes, global ulong* in) { if (get_global_id(0) == 0) { out[0] = 0; } - + ulong sum = 0; for (i32 p = get_global_id(0); p < sizeBytes / sizeof(u64); p += get_global_size(0)) { sum += in[p]; @@ -32,7 +32,7 @@ KERNEL(64) sum64(global ulong* out, u32 sizeBytes, global ulong* in) { #if ISEQUAL // outEqual must be "true" on entry. KERNEL(256) isEqual(global i64 *in1, global i64 *in2, P(int) outEqual) { - for (i32 p = get_global_id(0); p < ND; p += get_global_size(0)) { + for (i32 p = get_global_id(0); p < NWORDS * sizeof(Word) / sizeof(i64); p += get_global_size(0)) { if (in1[p] != in2[p]) { *outEqual = 0; return; @@ -50,5 +50,4 @@ kernel void testKernel(global int* in, global double* out) { int p = me * in[me] % 8; // % 15; out[me] = TAB[p]; } - #endif diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index bf61d555..f4634af3 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -2,8 +2,6 @@ #include "trig.cl" -void fft2(T2* u) { X2(u[0], u[1]); } - #if MIDDLE == 3 #include "fft3.cl" #elif MIDDLE == 4 @@ -34,7 +32,29 @@ void fft2(T2* u) { X2(u[0], u[1]); } #include "fft16.cl" #endif -void fft_MIDDLE(T2 *u) { +#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 0 +#define MM_CHAIN 0 +#define MM2_CHAIN 0 +#endif + +#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 1 +#define MM_CHAIN 1 +#define MM2_CHAIN 2 +#endif + +// Apply the twiddles needed after fft_MIDDLE and before fft_HEIGHT in forward FFT. +// Also used after fft_HEIGHT and before fft_MIDDLE in inverse FFT. + +#define WADD(i, w) u[i] = cmul(u[i], w) +#define WSUB(i, w) u[i] = cmul_by_conjugate(u[i], w) +#define WADDF(i, w) u[i] = cmulFancy(u[i], w) +#define WSUBF(i, w) u[i] = cmulFancy(u[i], conjugate(w)) + +#if FFT_FP64 + +void OVERLOAD fft2(T2* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(T2 *u) { #if MIDDLE == 1 // Do nothing #elif MIDDLE == 2 @@ -72,28 +92,15 @@ void fft_MIDDLE(T2 *u) { #endif } -// Apply the twiddles needed after fft_MIDDLE and before fft_HEIGHT in forward FFT. -// Also used after fft_HEIGHT and before fft_MIDDLE in inverse FFT. - -#define WADD(i, w) u[i] = cmul(u[i], w) -#define WSUB(i, w) u[i] = cmul_by_conjugate(u[i], w) - -#define WADDF(i, w) u[i] = cmulFancy(u[i], w) -#define WSUBF(i, w) u[i] = cmulFancy(u[i], conjugate(w)) - // Keep in sync with TrigBufCache.cpp, see comment there. #define SHARP_MIDDLE 5 -#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 1 -#define MM_CHAIN 1 -#define MM2_CHAIN 2 -#endif - -void middleMul(T2 *u, u32 s, Trig trig) { +void OVERLOAD middleMul(T2 *u, u32 s, Trig trig) { assert(s < SMALL_HEIGHT); if (MIDDLE == 1) { return; } - T2 w = trig[s]; // s / BIG_HEIGHT + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. + T2 w = trig[s]; // s / BIG_HEIGHT if (MIDDLE < SHARP_MIDDLE) { WADD(1, w); @@ -175,7 +182,7 @@ void middleMul(T2 *u, u32 s, Trig trig) { } } -void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { +void OVERLOAD middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { assert(x < WIDTH); assert(y < SMALL_HEIGHT); @@ -184,7 +191,8 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { return; } - T2 w = trig[SMALL_HEIGHT + x]; // x / (MIDDLE * WIDTH) + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + T2 w = trig[x]; // x / (MIDDLE * WIDTH) if (MIDDLE < SHARP_MIDDLE) { T2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; @@ -196,7 +204,19 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { } else { // MIDDLE >= 5 // T2 w = slowTrig_N(x * SMALL_HEIGHT, ND / MIDDLE); -#if AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops +#if 0 // Slower on Radeon 7, but proves the concept for use in GF61. Might be worthwhile on poor FP64 GPUs + + Trig trig2 = trig + WIDTH; // Skip over the fist MiddleMul2 trig table + u32 desired_root = x * y; + T2 base = cmulFancy(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]) * factor; //Optimization to do: put multiply by factor in trig2 table + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops T2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; WADD(0, base); @@ -281,13 +301,8 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { } } -#undef WADD -#undef WADDF -#undef WSUB -#undef WSUBF - // Do a partial transpose during fftMiddleIn/Out -void middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { +void OVERLOAD middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { u32 me = get_local_id(0); if (MIDDLE <= 8) { local T *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; @@ -323,10 +338,505 @@ void middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { } } +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global T2 *out, T2 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft2(F2* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(F2 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +// Keep in sync with TrigBufCache.cpp, see comment there. +#define SHARP_MIDDLE 5 + +void OVERLOAD middleMul(F2 *u, u32 s, TrigFP32 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) { return; } + + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. + F2 w = trig[s]; // s / BIG_HEIGHT + + if (MIDDLE < SHARP_MIDDLE) { + WADD(1, w); +#if MM_CHAIN == 0 + F2 base = csqTrig(w); + for (u32 k = 2; k < MIDDLE; ++k) { + WADD(k, base); + base = cmul(base, w); + } +#elif MM_CHAIN == 1 + for (u32 k = 2; k < MIDDLE; ++k) { WADD(k, slowTrig_N(WIDTH * k * s, WIDTH * k * SMALL_HEIGHT)); } +#else +#error MM_CHAIN must be 0 or 1 +#endif + + } else { // MIDDLE >= 5 + +#if MM_CHAIN == 0 + WADDF(1, w); + F2 base; + base = csqTrigFancy(w); + WADDF(2, base); + base = ccubeTrigFancy(base, w); + WADDF(3, base); + base.x += 1; + + for (u32 k = 4; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif 0 && MM_CHAIN == 1 // This is fewer F64 ops, but may be slower on Radeon 7 -- probably the optimizer being weird. It also has somewhat worse Z. + for (u32 k = 3 + (MIDDLE - 2) % 3; k < MIDDLE; k += 3) { + F2 base, base_minus1, base_plus1; + base = slowTrig_N(WIDTH * k * s, WIDTH * SMALL_HEIGHT * k); + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(k-1, base_minus1); + WADD(k, base); + WADD(k+1, base_plus1); + } + + WADDF(1, w); + + F2 w2; + if ((MIDDLE - 2) % 3 > 0) { + w2 = csqTrigFancy(w); + WADDF(2, w2); + } + + if ((MIDDLE - 2) % 3 == 2) { + F2 w3 = ccubeTrigFancy(w2, w); + WADDF(3, w3); + } + +#elif MM_CHAIN == 1 + for (u32 k = 3 + (MIDDLE - 2) % 3; k < MIDDLE; k += 3) { + F2 base, base_minus1, base_plus1; + base = slowTrig_N(WIDTH * k * s, WIDTH * SMALL_HEIGHT * k); + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(k-1, base_minus1); + WADD(k, base); + WADD(k+1, base_plus1); + } + + WADDF(1, w); + + if ((MIDDLE - 2) % 3 > 0) { + WADDF(2, w); + WADDF(2, w); + } + + if ((MIDDLE - 2) % 3 == 2) { + WADDF(3, w); + WADDF(3, csqTrigFancy(w)); + } +#else +#error MM_CHAIN must be 0 or 1. +#endif + } +} + +void OVERLOAD middleMul2(F2 *u, u32 x, u32 y, float factor, TrigFP32 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + if (MIDDLE == 1) { + WADD(0, slowTrig_N(x * y, ND) * factor); + return; + } + + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + F2 w = trig[x]; // x / (MIDDLE * WIDTH) + + if (MIDDLE < SHARP_MIDDLE) { + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; + for (u32 k = 0; k < MIDDLE; ++k) { WADD(k, base); } + WSUB(0, w); + if (MIDDLE > 2) { WADD(2, w); } + if (MIDDLE > 3) { WADD(3, w); WADD(3, w); } + + } else { // MIDDLE >= 5 + // F2 w = slowTrig_N(x * SMALL_HEIGHT, ND / MIDDLE); + +#if 0 // Slower on Radeon 7, but proves the concept for use in GF61. Might be worthwhile on poor FP64 GPUs + + TrigFP32 trig2 = trig + WIDTH; // Skip over the fist MiddleMul2 trig table + u32 desired_root = x * y; + F2 base = cmulFancy(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]) * factor; //Optimization to do: put multiply by factor in trig2 table + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops + + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; + WADD(0, base); + WADD(1, base); + + for (u32 k = 2; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + WSUBF(0, w); + +#elif MM2_CHAIN == 0 + + u32 mid = MIDDLE / 2; + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT * mid, ND / MIDDLE * (mid + 1)) * factor; + WADD(mid, base); + + F2 basehi, baselo; + cmul_a_by_fancyb_and_conjfancyb(&basehi, &baselo, base, w); + WADD(mid-1, baselo); + WADD(mid+1, basehi); + + for (int i = mid-2; i >= 0; --i) { + baselo = cmulFancy(baselo, conjugate(w)); + WADD(i, baselo); + } + + for (int i = mid+2; i < MIDDLE; ++i) { + basehi = cmulFancy(basehi, w); + WADD(i, basehi); + } + +#elif MM2_CHAIN == 1 + u32 cnt = 1; + for (u32 start = 0, sz = (MIDDLE - start + cnt - 1) / cnt; cnt > 0; --cnt, start += sz) { + if (start + sz > MIDDLE) { --sz; } + u32 n = (sz - 1) / 2; + u32 mid = start + n; + + F2 base1 = slowTrig_N(x * y + x * SMALL_HEIGHT * mid, ND / MIDDLE * (mid + 1)) * factor; + WADD(mid, base1); + + F2 base2 = base1; + for (u32 i = 1; i <= n; ++i) { + base1 = cmulFancy(base1, conjugate(w)); + WADD(mid - i, base1); + + base2 = cmulFancy(base2, w); + WADD(mid + i, base2); + } + if (!(sz & 1)) { + base2 = cmulFancy(base2, w); + WADD(mid + n + 1, base2); + } + } + +#elif MM2_CHAIN == 2 + F2 base, base_minus1, base_plus1; + for (u32 i = 1; ; i += 3) { + if (i-1 == MIDDLE-1) { + base_minus1 = slowTrig_N(x * y + x * SMALL_HEIGHT * (i - 1), ND / MIDDLE * i) * factor; + WADD(i-1, base_minus1); + break; + } else if (i == MIDDLE-1) { + base_minus1 = slowTrig_N(x * y + x * SMALL_HEIGHT * (i - 1), ND / MIDDLE * i) * factor; + base = cmulFancy(base_minus1, w); + WADD(i-1, base_minus1); + WADD(i, base); + break; + } else { + base = slowTrig_N(x * y + x * SMALL_HEIGHT * i, ND / MIDDLE * (i + 1)) * factor; + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(i-1, base_minus1); + WADD(i, base); + WADD(i+1, base_plus1); + if (i+1 == MIDDLE-1) break; + } + } +#else +#error MM2_CHAIN must be 0, 1 or 2. +#endif + } +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local F *lds, F2 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 16) { + local F *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local F *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } +} // Do a partial transpose during fftMiddleIn/Out and write the results to global memory -void middleShuffleWrite(global T2 *out, T2 *u, u32 workgroupSize, u32 blockSize) { +void OVERLOAD middleShuffleWrite(global F2 *out, F2 *u, u32 workgroupSize, u32 blockSize) { u32 me = get_local_id(0); out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft2(GF31* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(GF31 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +void OVERLOAD middleMul(GF31 *u, u32 s, TrigGF31 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) { return; } + + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. + GF31 w = trig[s]; // s / BIG_HEIGHT + + WADD(1, w); + GF31 base = csq(w); + for (u32 k = 2; k < MIDDLE; ++k) { + WADD(k, base); + base = cmul(base, w); + } +} + +void OVERLOAD middleMul2(GF31 *u, u32 x, u32 y, TrigGF31 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + GF31 w = trig[x]; // x / (MIDDLE * WIDTH) + + TrigGF31 trig2 = trig + WIDTH; // Skip over first MiddleMul2 trig table + u32 desired_root = x * y; + GF31 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]); + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmul(base, w); + WADD(k, base); + } + +#if 0 // Might save a couple of muls with cmul_a_by_b_and_conjb if we can compute "desired_root = x * y + x * SMALL_HEIGHT" with a slightly expanded trig table + GF31 base = slowTrigGF31(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2); + WADD(1, base); + + if (MIDDLE == 2) { + WADD(0, base); + WSUB(0, w); + return; + } + + GF31 basehi, baselo; + cmul_a_by_b_and_conjb(&basehi, &baselo, base, w); + WADD(0, baselo); + WADD(2, basehi); + + for (int i = 3; i < MIDDLE; ++i) { + basehi = cmul(basehi, w); + WADD(i, basehi); + } +#endif +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local Z31 *lds, GF31 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 16) { + local Z31 *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local Z31 *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } +} + +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global GF31 *out, GF31 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft2(GF61* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(GF61 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +void OVERLOAD middleMul(GF61 *u, u32 s, TrigGF61 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) { return; } + + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. + GF61 w = trig[s]; // s / BIG_HEIGHT + + WADD(1, w); + GF61 base = csq(w); + for (u32 k = 2; k < MIDDLE; ++k) { + WADD(k, base); + base = cmul(base, w); + } +} + +void OVERLOAD middleMul2(GF61 *u, u32 x, u32 y, TrigGF61 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + GF61 w = trig[x]; // x / (MIDDLE * WIDTH) + + TrigGF61 trig2 = trig + WIDTH; // Skip over first MiddleMul2 trig table + u32 desired_root = x * y; + GF61 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]); + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmul(base, w); + WADD(k, base); + } + +#if 0 // Might save a couple of muls with cmul_a_by_b_and_conjb if we can compute "desired_root = x * y + x * SMALL_HEIGHT" with a slightly expanded trig table + GF61 base = slowTrigGF61(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2); + WADD(1, base); + + if (MIDDLE == 2) { + WADD(0, base); + WSUB(0, w); + return; + } + + GF61 basehi, baselo; + cmul_a_by_b_and_conjb(&basehi, &baselo, base, w); + WADD(0, baselo); + WADD(2, basehi); + + for (int i = 3; i < MIDDLE; ++i) { + basehi = cmul(basehi, w); + WADD(i, basehi); + } +#endif +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local Z61 *lds, GF61 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 8) { + local Z61 *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local Z61 *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } else { + local int *p1 = ((local int*) lds) + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local int *p2 = (local int*) lds + me; + int4 *pu = (int4 *)u; + + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].y = p2[workgroupSize * i]; } + bar(); + + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].z; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].z = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].w; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].w = p2[workgroupSize * i]; } + } +} + +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global GF61 *out, GF61 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +#endif + + +#undef WADD +#undef WADDF +#undef WSUB +#undef WSUBF diff --git a/src/cl/fft10.cl b/src/cl/fft10.cl index 0ede9c75..fe23a0b5 100644 --- a/src/cl/fft10.cl +++ b/src/cl/fft10.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda +#if FFT_FP64 + #include "fft5.cl" // PFA(5*2): 24 FMA + 68 ADD @@ -14,3 +16,5 @@ void fft10(T2 *u) { CYCLE(2, 4, 8, 6); #undef CYCLE } + +#endif diff --git a/src/cl/fft11.cl b/src/cl/fft11.cl index 5ac4f982..9bc12b13 100644 --- a/src/cl/fft11.cl +++ b/src/cl/fft11.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda & George Woltman +#if FFT_FP64 + #if 0 // Adapted from https://web.archive.org/web/20101126231320/http://cnx.org/content/col10569/1.7/pdf // 40 FMA + 150 ADD @@ -178,3 +180,5 @@ void fft11(T2 *u) { } #endif + +#endif diff --git a/src/cl/fft12.cl b/src/cl/fft12.cl index 5f74cf8e..6f400edc 100644 --- a/src/cl/fft12.cl +++ b/src/cl/fft12.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + #if 1 #include "fft3.cl" #include "fft4.cl" @@ -75,3 +77,5 @@ void fft12(T2 *u) { } #endif + +#endif diff --git a/src/cl/fft13.cl b/src/cl/fft13.cl index 6cc27d9c..528fdc04 100644 --- a/src/cl/fft13.cl +++ b/src/cl/fft13.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // To calculate a 13-complex FFT in a brute force way (using a shorthand notation): // The sin/cos values (w = 13th root of unity) are: // w^1 = .885 - .465i @@ -155,3 +157,5 @@ void fft13(T2 *u) { fma_addsub(u[5], u[8], SIN1, tmp69a, tmp69b); fma_addsub(u[6], u[7], SIN1, tmp78a, tmp78b); } + +#endif diff --git a/src/cl/fft14.cl b/src/cl/fft14.cl index f786820c..8690cc74 100644 --- a/src/cl/fft14.cl +++ b/src/cl/fft14.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + #if 1 #include "fft7.cl" @@ -96,3 +98,5 @@ void fft14(T2 *u) { fma_addsub(u[6], u[8], SIN1, tmp79a, tmp79b); } #endif + +#endif diff --git a/src/cl/fft15.cl b/src/cl/fft15.cl index 3052f06c..beb8c895 100644 --- a/src/cl/fft15.cl +++ b/src/cl/fft15.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // The fft3by() and fft5by() below use a different "output map" relative to fft3.cl and fft5.cl // This way fft15() does not need a "fix order" step at the end. // See "An In-Place, In-Order Prime Factor Algorithm" by Burrus & Eschenbacher (1981) @@ -73,3 +75,5 @@ void fft15(T2 *u) { fft5_15(u, 5); fft5_15(u, 10); } + +#endif diff --git a/src/cl/fft16.cl b/src/cl/fft16.cl index a07599ba..5603a25c 100644 --- a/src/cl/fft16.cl +++ b/src/cl/fft16.cl @@ -1,5 +1,6 @@ // Copyright (C) Mihai Preda +#if FFT_FP64 #if 0 @@ -8,7 +9,7 @@ #include "fft4.cl" // 24 FMA (of which 16 MUL) + 136 ADD -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.92387953251128674, // cos(tau/16) S1 = 0.38268343236508978; // sin(tau/16) @@ -43,7 +44,7 @@ void fft16(T2 *u) { #include "fft8.cl" -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.92387953251128674, // cos(tau/16) S1 = 0.38268343236508978; // sin(tau/16) @@ -77,7 +78,7 @@ void fft16(T2 *u) { // FFT-16 Adapted from Nussbaumer, "Fast Fourier Transform and Convolution Algorithms" // 28 FMA + 124 ADD -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.70710678118654757, // cos(2t/16) C2 = 0.38268343236508978, // cos(3t/16) @@ -152,3 +153,129 @@ void fft16(T2 *u) { } #endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +#include "fft8.cl" + +void OVERLOAD fft16(F2 *u) { + float + C1 = 0.92387953251128674, // cos(tau/16) + S1 = 0.38268343236508978; // sin(tau/16) + + for (int i = 0; i < 8; ++i) { X2(u[i], u[i + 8]); } + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(-S1, C1)); // 5t16 + u[15] = cmul(u[15], U2(-C1, S1)); // 7t16 + + u[10] = mul_t8(u[10]); + u[14] = mul_3t8(u[14]); + + u[12] = mul_t4(u[12]); + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +#include "fft8.cl" + +void OVERLOAD fft16(GF31 *u) { + const Z31 C1 = 1556715293; + const Z31 S1 = 978592373; + + X2(u[0], u[8]); + X2(u[1], u[9]); + X2_mul_t8(u[2], u[10]); + X2(u[3], u[11]); + X2_mul_t4(u[4], u[12]); + X2(u[5], u[13]); + X2_mul_3t8(u[6], u[14]); + X2(u[7], u[15]); + + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(neg(S1), C1)); // 5t16 //GWBUG - check if optimizer is eliminating the neg (or better yet perhaps tweak follow up code to expect a negative) + u[15] = cmul(u[15], U2(neg(C1), S1)); // 7t16 + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +#include "fft8.cl" + +void OVERLOAD fft16(GF61 *u) { + const Z61 C1 = 22027337052962166ULL; + const Z61 S1 = 1693317751237720973ULL; + + X2(u[0], u[8]); + X2(u[1], u[9]); + X2_mul_t8(u[2], u[10]); + X2(u[3], u[11]); + X2_mul_t4(u[4], u[12]); + X2(u[5], u[13]); + X2_mul_3t8(u[6], u[14]); + X2(u[7], u[15]); + + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(neg(S1), C1)); // 5t16 //GWBUG - check if optimizer is eliminating the neg (or better yet perhaps tweak follow up code to expect a negative) + u[15] = cmul(u[15], U2(neg(C1), S1)); // 7t16 + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif diff --git a/src/cl/fft3.cl b/src/cl/fft3.cl index cc501705..ee5b7af3 100644 --- a/src/cl/fft3.cl +++ b/src/cl/fft3.cl @@ -2,6 +2,8 @@ #pragma once +#if FFT_FP64 + // 6 FMA + 6 ADD void fft3by(T2 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + k * step) % M] @@ -28,3 +30,5 @@ void fft3by(T2 *u, u32 base, u32 step, u32 M) { } void fft3(T2 *u) { fft3by(u, 0, 1, 3); } + +#endif diff --git a/src/cl/fft4.cl b/src/cl/fft4.cl index 02d740f0..388d4aef 100644 --- a/src/cl/fft4.cl +++ b/src/cl/fft4.cl @@ -2,7 +2,9 @@ #pragma once -void fft4Core(T2 *u) { +#if FFT_FP64 + +void OVERLOAD fft4Core(T2 *u) { X2(u[0], u[2]); X2(u[1], u[3]); u[3] = mul_t4(u[3]); @@ -11,7 +13,7 @@ void fft4Core(T2 *u) { } // 16 ADD -void fft4by(T2 *u, u32 base, u32 step, u32 M) { +void OVERLOAD fft4by(T2 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + step * k) % M] @@ -59,4 +61,220 @@ void fft4by(T2 *u, u32 base, u32 step, u32 M) { } -void fft4(T2 *u) { fft4by(u, 0, 1, 4); } +void OVERLOAD fft4(T2 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft4Core(F2 *u) { + X2(u[0], u[2]); + X2(u[1], u[3]); u[3] = mul_t4(u[3]); + + X2(u[0], u[1]); + X2(u[2], u[3]); +} + +// 16 ADD +void OVERLOAD fft4by(F2 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + +#if 1 + float x0 = A(0).x + A(2).x; + float x2 = A(0).x - A(2).x; + float y0 = A(0).y + A(2).y; + float y2 = A(0).y - A(2).y; + + float x1 = A(1).x + A(3).x; + float y3 = A(1).x - A(3).x; + float y1 = A(1).y + A(3).y; + float x3 = -(A(1).y - A(3).y); + + float a0 = x0 + x1; + float a1 = x0 - x1; + + float b0 = y0 + y1; + float b1 = y0 - y1; + + float a2 = x2 + x3; + float a3 = x2 - x3; + + float b2 = y2 + y3; + float b3 = y2 - y3; + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#else + + X2(A(0), A(2)); + X2(A(1), A(3)); + X2(A(0), A(1)); + + A(3) = mul_t4(A(3)); + X2(A(2), A(3)); + SWAP(A(1), A(2)); + +#endif + +#undef A + +} + +void OVERLOAD fft4(F2 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft4Core(GF31 *u) { + X2(u[0], u[2]); + X2(u[1], u[3]); u[3] = mul_t4(u[3]); + + X2(u[0], u[1]); + X2(u[2], u[3]); +} + +// 16 ADD +void OVERLOAD fft4by(GF31 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + + Z31 x0 = add(A(0).x, A(2).x); //GWBUG: Delay some of the mods (we have three spare bits) + Z31 x2 = sub(A(0).x, A(2).x); + Z31 y0 = add(A(0).y, A(2).y); + Z31 y2 = sub(A(0).y, A(2).y); + + Z31 x1 = add(A(1).x, A(3).x); + Z31 y3 = sub(A(1).x, A(3).x); + Z31 y1 = add(A(1).y, A(3).y); + Z31 x3 = sub(A(3).y, A(1).y); + + Z31 a0 = add(x0, x1); + Z31 a1 = sub(x0, x1); + + Z31 b0 = add(y0, y1); + Z31 b1 = sub(y0, y1); + + Z31 a2 = add(x2, x3); + Z31 a3 = sub(x2, x3); + + Z31 b2 = add(y2, y3); + Z31 b3 = sub(y2, y3); + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#undef A + +} + +void OVERLOAD fft4(GF31 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft4Core(GF61 *u) { // Starts with all u[i] having maximum values of M61+epsilon. + X2q(&u[0], &u[2], 2); // X2(u[0], u[2]); No reductions mod M61. Will require 3 M61s additions to make positives. + X2q_mul_t4(&u[1], &u[3], 2); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2s(&u[0], &u[1], 3); + X2s(&u[2], &u[3], 3); +} + +// 16 ADD +void OVERLOAD fft4by(GF61 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + +#if !TEST_SHL + Z61 x0 = addq(A(0).x, A(2).x); // Max value is 2*M61+epsilon + Z61 x2 = subq(A(0).x, A(2).x, 2); // Max value is 3*M61+epsilon + Z61 y0 = addq(A(0).y, A(2).y); + Z61 y2 = subq(A(0).y, A(2).y, 2); + + Z61 x1 = addq(A(1).x, A(3).x); + Z61 y3 = subq(A(1).x, A(3).x, 2); + Z61 y1 = addq(A(1).y, A(3).y); + Z61 x3 = subq(A(3).y, A(1).y, 2); + + Z61 a0 = add(x0, x1); + Z61 a1 = subs(x0, x1, 3); + + Z61 b0 = add(y0, y1); + Z61 b1 = subs(y0, y1, 3); + + Z61 a2 = add(x2, x3); + Z61 a3 = subs(x2, x3, 4); + + Z61 b2 = add(y2, y3); + Z61 b3 = subs(y2, y3, 4); + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#else // Test case to see if signed M61 mod would be faster (if so, look into creating X2q options in math.cl's GF61 to support signed intermediates) + + i64 x0 = A(0).x + A(2).x; + i64 x2 = A(0).x - A(2).x; + i64 y0 = A(0).y + A(2).y; + i64 y2 = A(0).y - A(2).y; + + i64 x1 = A(1).x + A(3).x; + i64 y3 = A(1).x - A(3).x; + i64 y1 = A(1).y + A(3).y; + i64 x3 = A(3).y - A(1).y; + + i64 a0 = x0 - x1; + i64 a1 = x0 - x1; + + i64 b0 = y0 + y1; + i64 b1 = y0 - y1; + + i64 a2 = x2 + x3; + i64 a3 = x2 - x3; + + i64 b2 = y2 + y3; + i64 b3 = y2 - y3; + +#define cvt(a) (Z61) ((a & M61) + (a >> MBITS)) + + A(0) = U2(cvt(a0), cvt(b0)); + A(1) = U2(cvt(a2), cvt(b2)); + A(2) = U2(cvt(a1), cvt(b1)); + A(3) = U2(cvt(a3), cvt(b3)); + +#undef cvt + +#endif + + +#undef A + +} + +void OVERLOAD fft4(GF61 *u) { fft4by(u, 0, 1, 4); } + +#endif diff --git a/src/cl/fft5.cl b/src/cl/fft5.cl index 2cf446b2..31e2ba9d 100644 --- a/src/cl/fft5.cl +++ b/src/cl/fft5.cl @@ -2,6 +2,8 @@ #pragma once +#if FFT_FP64 + // Adapted from: Nussbaumer, "Fast Fourier Transform and Convolution Algorithms", 5.5.4 "5-Point DFT". // 12 FMA + 24 ADD (or 10 FMA + 28 ADD) void fft5by(T2 *u, u32 base, u32 step, u32 m) { @@ -45,3 +47,5 @@ void fft5by(T2 *u, u32 base, u32 step, u32 m) { } void fft5(T2 *u) { return fft5by(u, 0, 1, 5); } + +#endif diff --git a/src/cl/fft6.cl b/src/cl/fft6.cl index d25466fe..54299286 100644 --- a/src/cl/fft6.cl +++ b/src/cl/fft6.cl @@ -2,6 +2,8 @@ #include "fft3.cl" +#if FFT_FP64 + // 12 FMA + 24 ADD void fft6(T2 *u) { #if 1 @@ -36,3 +38,5 @@ void fft6(T2 *u) { fma_addsub(u[2], u[4], -SIN1, tmp35a, u[2]); #endif } + +#endif diff --git a/src/cl/fft7.cl b/src/cl/fft7.cl index 1424148e..450cfa7d 100644 --- a/src/cl/fft7.cl +++ b/src/cl/fft7.cl @@ -4,6 +4,8 @@ #include "base.cl" +#if FFT_FP64 + #define A(i) u[(base + i * step) % M] #if 1 @@ -108,3 +110,5 @@ void fft7by(T2 *u, u32 base, u32 step, u32 M) { #undef A void fft7(T2 *u) { return fft7by(u, 0, 1, 7); } + +#endif diff --git a/src/cl/fft8.cl b/src/cl/fft8.cl index 9b113795..bf4332d7 100644 --- a/src/cl/fft8.cl +++ b/src/cl/fft8.cl @@ -4,31 +4,185 @@ #include "fft4.cl" +#if FFT_FP64 + T2 mul_t8_delayed(T2 a) { return U2(a.x - a.y, a.x + a.y); } T2 mul_3t8_delayed(T2 a) { return U2(-(a.x + a.y), a.x - a.y); } //#define X2_apply_delay(a, b) { T2 t = a; a = t + M_SQRT1_2 * b; b = t - M_SQRT1_2 * b; } #define X2_apply_delay(a, b) { T2 t = a; a.x = fma(b.x, M_SQRT1_2, a.x); a.y = fma(b.y, M_SQRT1_2, a.y); b.x = fma(-M_SQRT1_2, b.x, t.x); b.y = fma(-M_SQRT1_2, b.y, t.y); } -void fft4CoreSpecial(T2 *u) { +void OVERLOAD fft4CoreSpecial(T2 *u) { + X2(u[0], u[2]); + X2_mul_t4(u[1], u[3]); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2_apply_delay(u[0], u[1]); + X2_apply_delay(u[2], u[3]); +} + +void OVERLOAD fft8Core(T2 *u) { + X2(u[0], u[4]); + X2(u[1], u[5]); u[5] = mul_t8_delayed(u[5]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2(u[3], u[7]); u[7] = mul_3t8_delayed(u[7]); + fft4Core(u); + fft4CoreSpecial(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(T2 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 mul_t8_delayed(F2 a) { return U2(a.x - a.y, a.x + a.y); } +F2 mul_3t8_delayed(F2 a) { return U2(-(a.x + a.y), a.x - a.y); } +//#define X2_apply_delay(a, b) { F2 t = a; a = t + M_SQRT1_2 * b; b = t - M_SQRT1_2 * b; } +#define X2_apply_delay(a, b) { F2 t = a; a.x = fma(b.x, (float) M_SQRT1_2, a.x); a.y = fma(b.y, (float) M_SQRT1_2, a.y); b.x = fma((float) -M_SQRT1_2, b.x, t.x); b.y = fma((float) -M_SQRT1_2, b.y, t.y); } + +void OVERLOAD fft4CoreSpecial(F2 *u) { X2(u[0], u[2]); - X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2_mul_t4(u[1], u[3]); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); X2_apply_delay(u[0], u[1]); X2_apply_delay(u[2], u[3]); } -void fft8Core(T2 *u) { +void OVERLOAD fft8Core(F2 *u) { X2(u[0], u[4]); X2(u[1], u[5]); u[5] = mul_t8_delayed(u[5]); - X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); X2(u[3], u[7]); u[7] = mul_3t8_delayed(u[7]); fft4Core(u); fft4CoreSpecial(u + 4); } // 4 MUL + 52 ADD -void fft8(T2 *u) { +void OVERLOAD fft8(F2 *u) { fft8Core(u); // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo SWAP(u[1], u[4]); SWAP(u[3], u[6]); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft8Core(GF31 *u) { + X2(u[0], u[4]); //GWBUG: Delay some mods using extra 3 bits of Z61 + X2(u[1], u[5]); u[5] = mul_t8(u[5]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2(u[3], u[7]); u[7] = mul_3t8(u[7]); + fft4Core(u); + fft4Core(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(GF31 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +#if 0 // Working code. Fairly readable. + +// Same as mul_t8, but negation of a.y is delayed +GF61 OVERLOAD mul_t8_special(GF61 a) { return U2(shl(a.y + neg(a.x, 2), 30), shl(a.x + a.y, 30)); } +// Same as neg(a.y), X2_mul_t4(a, b) +void OVERLOAD X2_mul_t4_special(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(b->y, a->y); t.x = sub(t.x, b->x); b->x = add(b->y, t.y); b->y = t.x; } + +void OVERLOAD fft4CoreSpecialU1(GF61 *u) { // u[1].y needs negation + X2(u[0], u[2]); + X2_mul_t4_special(&u[1], &u[3]); // u[1].y = -u[1].y; X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2(u[0], u[1]); + X2(u[2], u[3]); +} + +void OVERLOAD fft8Core(GF61 *u) { + X2(u[0], u[4]); //GWBUG: Delay some mods using extra 3 bits of Z61 + X2(u[1], u[5]); u[5] = mul_t8_special(u[5]); // u[5] = mul_t8(u[5]); But u[5].y needs negation + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2(u[3], u[7]); u[7] = mul_3t8(u[7]); + fft4Core(u); + fft4CoreSpecialU1(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(GF61 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#else // Carefully track the size of numbers to reduce the numberof mod M61 reductions + +// Same as mul_t8, but negation of a.y is delayed and a custom m61_count +GF61 OVERLOAD mul_t8_special(GF61 a, u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), a.x + a.y), 30); } +// Same as mul_3t8, but with a custom m61_count +GF61 OVERLOAD mul_3t8_special(GF61 a, u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } +// Same as neg(a.y), X2q_mul_t4(a, b, m61_count) +void OVERLOAD X2q_mul_t4_special(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; a->x = a->x + b->x; a->y = b->y + neg(a->y, m61_count); t.x = t.x + neg(b->x, m61_count); b->x = b->y + t.y; b->y = t.x; } + +void OVERLOAD fft4CoreSpecial1(GF61 *u) { // Starts with u[0,1,2,3] having maximum values of (2,2,3,2)*M61+epsilon. + X2q(&u[0], &u[2], 4); // X2(u[0], u[2]); No reductions mod M61. u[0,2] max value is 5,6*M61+epsilon. + X2q_mul_t4(&u[1], &u[3], 3); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 5,4*M61+epsilon. + u[1] = mod(u[1]); u[2] = mod(u[2]); // Reduce the worst offenders. u[0,1,2,3] have maximum values of (5,1,1,4)*M61+epsilon. + X2s(&u[0], &u[1], 2); // u[0,1] max value before reduction is 6,7*M61+epsilon + X2s(&u[2], &u[3], 5); // u[2,3] max value before reduction is 5,6*M61+epsilon +} + +void OVERLOAD fft4CoreSpecial2(GF61 *u) { // Like above, u[1].y needs negation. Starts with u[0,1,2,3] having maximum values of (3,1,2,1)*M61+epsilon. + X2q(&u[0], &u[2], 3); // u[0,2] max value is 5,6*M61+epsilon. + X2q_mul_t4_special(&u[1], &u[3], 2); // u[1].y = -u[1].y; X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 3,2*M61+epsilon. + u[0] = mod(u[0]); u[2] = mod(u[2]); // Reduce the worst offenders u[0,1,2,3] have maximum values of (1,3,1,2)*M61+epsilon. + X2s(&u[0], &u[1], 4); // u[0,1] max value before reduction is 4,5*M61+epsilon + X2s(&u[2], &u[3], 3); // u[2,3] max value before reduction is 3,4*M61+epsilon +} + +void OVERLOAD fft8Core(GF61 *u) { // Starts with all u[i] having maximum values of M61+epsilon. + X2q(&u[0], &u[4], 2); // X2(u[0], u[4]); No reductions mod M61. u[0,4] max value is 2,3*M61+epsilon. + X2q(&u[1], &u[5], 2); // X2(u[1], u[5]); u[1,5] max value is 2,3*M61+epsilon. + u[5] = mul_t8_special(u[5], 4); // u[5] = mul_t8(u[5]); u[5].y needs neg. u[5] max value is 1*M61+epsilon. + X2q_mul_t4(&u[2], &u[6], 2); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); u[2,6] max value is 3,2*M61+epsilon. + X2q(&u[3], &u[7], 2); // X2(u[3], u[7]); u[3,7] max value is 2,3*M61+epsilon. + u[7] = mul_3t8_special(u[7], 4); // u[7] = mul_3t8(u[7]); u[7] max value is 1*M61+epsilon. + fft4CoreSpecial1(u); + fft4CoreSpecial2(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(GF61 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + +#endif diff --git a/src/cl/fft9.cl b/src/cl/fft9.cl index a077d551..c980ec11 100644 --- a/src/cl/fft9.cl +++ b/src/cl/fft9.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // Adapted from: Nussbaumer, "Fast Fourier Transform and Convolution Algorithms", 5.5.7 "9-Point DFT". // 12 FMA + 8 MUL, 72 ADD void fft9(T2 *u) { @@ -61,3 +63,5 @@ void fft9(T2 *u) { X2(u[2], u[7]); X2(u[1], u[8]); } + +#endif diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index 4e5f9584..3c80008e 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -5,7 +5,9 @@ #include "trig.cl" // #include "math.cl" -void chainMul4(T2 *u, T2 w) { +#if FFT_FP64 + +void OVERLOAD chainMul4(T2 *u, T2 w) { u[1] = cmul(u[1], w); T2 base = csqTrig(w); @@ -19,7 +21,7 @@ void chainMul4(T2 *u, T2 w) { #if 1 // This version of chainMul8 tries to minimize roundoff error even if more F64 ops are used. // Trial and error looking at Z values on a WIDTH=512 FFT was used to determine when to switch from fancy to non-fancy powers of w. -void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { u[1] = cmulFancy(u[1], w); T2 w2; @@ -57,7 +59,7 @@ void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { // This version is slower on R7Pro due to a rocm optimizer issue in double-wide single-kernel tailSquare using BCAST. I could not find a work-around. // Other GPUs??? This version might be useful. If we decide to make this available, it will need a new width and height fft spec number. // Consequently, an increase in the BPW table and increase work for -ztune and -tune. -void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { u[1] = cmulFancy(u[1], w); T2 w2 = csqTrigFancy(w); @@ -77,7 +79,7 @@ void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { } #endif -void chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { // Do a length 4 chain mul, w must not be in Fancy format if (len == 4) chainMul4(u, w); // Do a length 8 chain mul, w must be in Fancy format @@ -104,7 +106,7 @@ T2 bcast(T2 src, u32 span) { #endif -void shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { +void OVERLOAD shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); u32 mask = f - 1; assert((mask & (mask + 1)) == 0); @@ -114,7 +116,7 @@ void shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } } -void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); local T* lds = (local T*) lds2; @@ -133,7 +135,7 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { // Lower LDS requirements should let the optimizer use fewer VGPRs and increase occupancy for WIDTHs >= 1024. // Alas, the increased occupancy does not offset extra code needed for shufl_int (the assembly // code generated is not pretty). This might not be true for nVidia or future ROCm optimizers. -void shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); local int* lds = (local int*) lds2; @@ -161,7 +163,7 @@ void shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { // NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. // Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half // of lds memory (first SMALL_HEIGHT T2 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT T2 values). -void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); // Partition lds memory into upper and lower halves @@ -173,7 +175,7 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { me = me % WG; u32 mask = f - 1; assert((mask & (mask + 1)) == 0); - + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } bar(WG); for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } @@ -183,7 +185,7 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } } -void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { +void OVERLOAD tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { #if 0 u32 p = me / f * f; #else @@ -470,3 +472,373 @@ void finish_tabMul8_fft8(u32 WG, local T2 *lds, Trig trig, T *preloads, T2 *u, u SWAP(u[1], u[4]); SWAP(u[3], u[6]); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD chainMul4(F2 *u, F2 w) { + u[1] = cmul(u[1], w); + + F2 base = csqTrig(w); + u[2] = cmul(u[2], base); + + F a = mul2(base.y); + base = U2(fma(a, -w.y, w.x), fma(a, w.x, -w.y)); + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(F2 *u, F2 w, u32 tailSquareBcast) { + u[1] = cmulFancy(u[1], w); + //GWBUG - see FP64 version for many possible optimizations + F2 w2 = csqTrigFancy(w); + u[2] = cmulFancy(u[2], w2); + + F2 w3 = ccubeTrigFancy(w2, w); + u[3] = cmulFancy(u[3], w3); + + w3.x += 1; + F2 base = cmulFancy (w3, w); + for (int i = 4; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmulFancy(base, w); + } +} + +void OVERLOAD chainMul(u32 len, F2 *u, F2 w, u32 tailSquareBcast) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w, tailSquareBcast); +} + +void OVERLOAD shuflBigLDS(u32 WG, local F2 *lds, F2 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local F2 *lds2, F2 *u, u32 n, u32 f) { //GWBUG - is shufl of int2 faster (BigLDS)? + u32 me = get_local_id(0); + local F* lds = (local F*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. +// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF31 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF31 values). +void OVERLOAD shufl2(u32 WG, local F2 *lds2, F2 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as F is faster than F2 accesses //GWBUG??? + local F* lds = ((local F*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigFP32 trig, F2 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN) { + chainMul (n, u, trig[p], 0); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. + + if (!TABMUL_CHAIN) { + if (n >= 8) { + u[1] = cmulFancy(u[1], trig[p]); + } else { + u[1] = cmul(u[1], trig[p]); + } + for (u32 i = 2; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD chainMul4(GF31 *u, GF31 w) { + u[1] = cmul(u[1], w); + + GF31 base = csq(w); + u[2] = cmul(u[2], base); + + base = cmul(base, w); //GWBUG - see FP64 version for possible optimization + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(GF31 *u, GF31 w, u32 tailSquareBcast) { + u[1] = cmul(u[1], w); + + GF31 w2 = csq(w); + u[2] = cmul(u[2], w2); + + GF31 base = cmul (w2, w); //GWBUG - see FP64 version for many possible optimizations + for (int i = 3; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmul(base, w); + } +} + +void OVERLOAD chainMul(u32 len, GF31 *u, GF31 w, u32 tailSquareBcast) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w, tailSquareBcast); +} + +void OVERLOAD shuflBigLDS(u32 WG, local GF31 *lds, GF31 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local GF31 *lds2, GF31 *u, u32 n, u32 f) { //GWBUG - is shufl of int2 faster (BigLDS)? + u32 me = get_local_id(0); + local Z31* lds = (local Z31*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. +// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF31 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF31 values). +void OVERLOAD shufl2(u32 WG, local GF31 *lds2, GF31 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as Z31s is faster than GF31 accesses //GWBUG??? + local Z31* lds = ((local Z31*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigGF31 trig, GF31 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN) { + chainMul (n, u, trig[p], 0); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. + + if (!TABMUL_CHAIN) { + for (u32 i = 1; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD chainMul4(GF61 *u, GF61 w) { + u[1] = cmul(u[1], w); + + GF61 base = csq(w); + u[2] = cmul(u[2], base); + + base = cmul(base, w); //GWBUG - see FP64 version for possible optimization + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(GF61 *u, GF61 w, u32 tailSquareBcast) { + u[1] = cmul(u[1], w); + + GF61 w2 = csq(w); + u[2] = cmul(u[2], w2); + + GF61 base = cmul (w2, w); //GWBUG - see FP64 version for many possible optimizations + for (int i = 3; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmul(base, w); + } +} + +void OVERLOAD chainMul(u32 len, GF61 *u, GF61 w, u32 tailSquareBcast) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w, tailSquareBcast); +} + +void OVERLOAD shuflBigLDS(u32 WG, local GF61 *lds, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + local Z61* lds = (local Z61*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Same as shufl but use ints instead of Z61s to reduce LDS memory requirements. +// Lower LDS requirements should let the optimizer use fewer VGPRs and increase occupancy for WIDTHs >= 1024. +// Alas, the increased occupancy does not offset extra code needed for shufl_int (the assembly +// code generated is not pretty). This might not be true for nVidia or future ROCm optimizers. +void OVERLOAD shufl_int(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + local int* lds = (local int*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).x; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.x = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).y; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.y = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).z; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.z = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).w; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.w = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); // I'm not sure why this barrier call is needed +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. +// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF61 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF61 values). +void OVERLOAD shufl2(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as Z61s is faster than GF61 accesses + local Z61* lds = ((local Z61*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigGF61 trig, GF61 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN) { + chainMul (n, u, trig[p], 0); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. + + if (!TABMUL_CHAIN) { + for (u32 i = 1; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif diff --git a/src/cl/fftheight.cl b/src/cl/fftheight.cl index 2ceb9326..a2f67cb2 100644 --- a/src/cl/fftheight.cl +++ b/src/cl/fftheight.cl @@ -10,7 +10,9 @@ u32 transPos(u32 k, u32 middle, u32 width) { return k / width + k % width * middle; } -void fft_NH(T2 *u) { +#if FFT_FP64 + +void OVERLOAD fft_NH(T2 *u) { #if NH == 4 fft4(u); #elif NH == 8 @@ -29,7 +31,7 @@ void fft_NH(T2 *u) { #error FFT_VARIANT_H == 0 only supported by AMD GPUs #endif -void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(); } fft_NH(u); @@ -42,7 +44,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { fft_NH(u); } -void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 WG = SMALL_HEIGHT / NH; for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(WG); } @@ -58,7 +60,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { #else -void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 me = get_local_id(0); #if !UNROLL_H @@ -74,7 +76,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { fft_NH(u); } -void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 me = get_local_id(0); u32 WG = SMALL_HEIGHT / NH; @@ -93,9 +95,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { #endif - - -void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { +void OVERLOAD new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { u32 WG = SMALL_HEIGHT / NH; u32 me = get_local_id(0); // This line mimics shufl2 -- partition lds into halves @@ -205,3 +205,169 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { void new_fft_HEIGHT2_1(local T2 *lds, T2 *u, Trig trig, T2 w) { new_fft_HEIGHT2(lds, u, trig, w, 1); } void new_fft_HEIGHT2_2(local T2 *lds, T2 *u, Trig trig, T2 w) { new_fft_HEIGHT2(lds, u, trig, w, 2); } +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft_NH(F2 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void new_fft_HEIGHT2_1(local F2 *lds, F2 *u, TrigFP32 trig) { fft_HEIGHT2(lds, u, trig); } +void new_fft_HEIGHT2_2(local F2 *lds, F2 *u, TrigFP32 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft_NH(GF31 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD new_fft_HEIGHT2_1(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_HEIGHT2(lds, u, trig); } +void OVERLOAD new_fft_HEIGHT2_2(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft_NH(GF61 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD new_fft_HEIGHT2_1(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_HEIGHT2(lds, u, trig); } +void OVERLOAD new_fft_HEIGHT2_2(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif diff --git a/src/cl/ffthin.cl b/src/cl/ffthin.cl index a869a2a0..b579132b 100644 --- a/src/cl/ffthin.cl +++ b/src/cl/ffthin.cl @@ -4,7 +4,9 @@ #include "math.cl" #include "fftheight.cl" -// Do an FFT Height after a transposeW (which may not have fully transposed data, leading to non-sequential input) +#if FFT_FP64 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { local T2 lds[SMALL_HEIGHT / 2]; @@ -25,3 +27,97 @@ KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { write(G_H, NH, u, out, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT / 2]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) smallTrig; + + F2 u[NH]; + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + readTailFusedLine(inF2, u, g, me); + +#if NH == 8 + F2 w = fancyTrig_N(ND / SMALL_HEIGHT * me); +#else + F2 w = slowTrig_N(ND / SMALL_HEIGHT * me, ND / NH); +#endif + + fft_HEIGHT(lds, u, smallTrigF2, w); + + write(G_H, NH, u, outF2, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHinGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + u32 g = get_group_id(0); + + u32 me = get_local_id(0); + + readTailFusedLine(in31, u, g, me); + + fft_HEIGHT(lds, u, smallTrig31); + + write(G_H, NH, u, out31, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHinGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + u32 g = get_group_id(0); + + u32 me = get_local_id(0); + + readTailFusedLine(in61, u, g, me); + + fft_HEIGHT(lds, u, smallTrig61); + + write(G_H, NH, u, out61, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif diff --git a/src/cl/fftmiddlein.cl b/src/cl/fftmiddlein.cl index 830fa2b3..18b0835b 100644 --- a/src/cl/fftmiddlein.cl +++ b/src/cl/fftmiddlein.cl @@ -5,13 +5,15 @@ #include "fft-middle.cl" #include "middle.cl" +#if FFT_FP64 + KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { T2 u[MIDDLE]; - + u32 SIZEY = IN_WG / IN_SIZEX; u32 N = WIDTH / IN_SIZEX; - + u32 g = get_group_id(0); u32 gx = g % N; u32 gy = g / N; @@ -46,3 +48,170 @@ KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { writeMiddleInLine(out, u, gy, gx); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { + F2 u[MIDDLE]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, inF2, y, x); + + middleMul2(u, x, y, 1, trigF2); + + fft_MIDDLE(u); + + middleMul(u, y, trigF2); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local F lds[IN_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + outF2 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + outF2 += mx * SIZEY + my; +#endif + + writeMiddleInLine(outF2, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(IN_WG) fftMiddleInGF31(P(T2) out, CP(T2) in, Trig trig) { + GF31 u[MIDDLE]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, in31, y, x); + + middleMul2(u, x, y, trig31); + + fft_MIDDLE(u); + + middleMul(u, y, trig31); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local Z31 lds[IN_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + out31 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out31 += mx * SIZEY + my; +#endif + + writeMiddleInLine(out31, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(IN_WG) fftMiddleInGF61(P(T2) out, CP(T2) in, Trig trig) { + GF61 u[MIDDLE]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, in61, y, x); + + middleMul2(u, x, y, trig61); + + fft_MIDDLE(u); + + middleMul(u, y, trig61); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local Z61 lds[IN_WG / 2 * (MIDDLE <= 8 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + out61 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out61 += mx * SIZEY + my; +#endif + + writeMiddleInLine(out61, u, gy, gx); +} + +#endif diff --git a/src/cl/fftmiddleout.cl b/src/cl/fftmiddleout.cl index 6404c7ae..43e3e48f 100644 --- a/src/cl/fftmiddleout.cl +++ b/src/cl/fftmiddleout.cl @@ -5,6 +5,8 @@ #include "fft-middle.cl" #include "middle.cl" +#if FFT_FP64 + KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { T2 u[MIDDLE]; @@ -56,3 +58,188 @@ KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { writeMiddleOutLine(out, u, gy, gx); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { + F2 u[MIDDLE]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. + // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, inF2, y, x); + + middleMul(u, x, trigF2); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. + double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trigF2); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local F lds[OUT_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + outF2 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + outF2 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(outF2, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(OUT_WG) fftMiddleOutGF31(P(T2) out, P(T2) in, Trig trig) { + GF31 u[MIDDLE]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. + // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, in31, y, x); + + middleMul(u, x, trig31); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig31); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local Z31 lds[OUT_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + out31 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out31 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(out31, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(OUT_WG) fftMiddleOutGF61(P(T2) out, P(T2) in, Trig trig) { + GF61 u[MIDDLE]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. + // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, in61, y, x); + + middleMul(u, x, trig61); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig61); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local Z61 lds[OUT_WG / 2 * (MIDDLE <= 8 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + out61 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out61 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(out61, u, gy, gx); +} + +#endif diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index 70a707d8..9e59db30 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -6,28 +6,467 @@ #include "fftwidth.cl" #include "middle.cl" +#if FFT_FP64 & !COMBO_FFT + // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { local T2 lds[WIDTH / 2]; - T2 u[NW]; + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + T base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + for (u32 i = 0; i < NW; ++i) { + T w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + T w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 p = G_W * i + me; + u[i] = U2(in[p].x * w1, in[p].y * w2); + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_FP32 & !COMBO_FFT + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(F2) out, CP(Word2) in, TrigFP32 smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local F2 lds[WIDTH / 2]; + F2 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + for (u32 i = 0; i < NW; ++i) { + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 p = G_W * i + me; + u[i] = U2(in[p].x * w1, in[p].y * w2); + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & !COMBO_FFT + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(GF31) out, CP(Word2) in, TrigGF31 smallTrig) { + local GF31 lds[WIDTH / 2]; + GF31 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight inputs + u[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF61 & !COMBO_FFT + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(GF61) out, CP(Word2) in, TrigGF61 smallTrig) { + local GF61 lds[WIDTH / 2]; + GF61 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Convert and weight input + u[i] = U2(shl(make_Z61(in[p].x), weight_shift0), shl(make_Z61(in[p].y), weight_shift1)); // Form a GF61 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} - u32 step = WIDTH * g; - in += step; +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP64 & NTT_GF31 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { + local T2 lds[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds; + T2 u[NW]; + GF31 u31[NW]; + + u32 g = get_group_id(0); u32 me = get_local_id(0); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + in += g * WIDTH; + T base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP64 weights and the second GF31 weight shift T w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); T w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); - u32 p = G_W * i + me; - u[i] = U2(in[p].x, in[p].y) * U2(w1, w2); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight input + u[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; } fft_WIDTH(lds, u, smallTrig); - writeCarryFusedLine(u, out, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); } + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local F2 ldsF2[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) ldsF2; + F2 uF2[NW]; + GF31 u31[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF31 weight shift + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_FP32 & NTT_GF61 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + F2 uF2[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF61 weight shift + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u61[i] = U2(shl(make_Z61(in[p].x), weight_shift0), shl(make_Z61(in[p].y), weight_shift1)); // Form a GF61 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif NTT_GF31 & NTT_GF61 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { + local GF61 lds61[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds61; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m31_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m61_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + // Convert and weight input + u31[i] = U2(shl(make_Z31(in[p].x), m31_weight_shift0), shl(make_Z31(in[p].y), m31_weight_shift1)); // Form a GF31 from each pair of input words + u61[i] = U2(shl(make_Z61(in[p].x), m61_weight_shift0), shl(make_Z61(in[p].y), m61_weight_shift1)); // Form a GF61 from each pair of input words + +// Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); +} + + +#else +error - missing FFTp kernel implementation +#endif diff --git a/src/cl/fftw.cl b/src/cl/fftw.cl index afefc547..a19b26d0 100644 --- a/src/cl/fftw.cl +++ b/src/cl/fftw.cl @@ -5,10 +5,12 @@ #include "fftwidth.cl" #include "middle.cl" -// Do an fft_WIDTH after a transposeH (which may not have fully transposed data, leading to non-sequential input) +#if FFT_FP64 + +// Do the ending fft_WIDTH after an fftMiddleOut. This is the same as the first half of carryFused. KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { local T2 lds[WIDTH / 2]; - + T2 u[NW]; u32 g = get_group_id(0); @@ -17,3 +19,81 @@ KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { out += WIDTH * g; write(G_W, NW, u, out, 0); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Do the ending fft_WIDTH after an fftMiddleOut. This is the same as the first half of carryFused. +KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[WIDTH / 2]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(inF2, u, g); + fft_WIDTH(lds, u, smallTrigF2); + outF2 += WIDTH * g; + write(G_W, NW, u, outF2, 0); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(G_W) fftWGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[WIDTH / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + GF31 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(in31, u, g); + fft_WIDTH(lds, u, smallTrig31); + out31 += WIDTH * g; + write(G_W, NW, u, out31, 0); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(G_W) fftWGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[WIDTH / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + GF61 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(in61, u, g); + fft_WIDTH(lds, u, smallTrig61); + out61 += WIDTH * g; + write(G_W, NW, u, out61, 0); +} + +#endif diff --git a/src/cl/fftwidth.cl b/src/cl/fftwidth.cl index 92cfc47d..d0589ab0 100644 --- a/src/cl/fftwidth.cl +++ b/src/cl/fftwidth.cl @@ -6,7 +6,9 @@ #error WIDTH must be one of: 256, 512, 1024, 4096, 625 #endif -void fft_NW(T2 *u) { +#if FFT_FP64 + +void OVERLOAD fft_NW(T2 *u) { #if NW == 4 fft4(u); #elif NW == 5 @@ -27,7 +29,7 @@ void fft_NW(T2 *u) { #error FFT_VARIANT_W == 0 only supported by AMD GPUs #endif -void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { +void OVERLOAD fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { u32 me = get_local_id(0); #if NW == 8 T2 w = fancyTrig_N(ND / WIDTH * me); @@ -49,7 +51,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { #else -void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { +void OVERLOAD fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { u32 me = get_local_id(0); #if !UNROLL_W @@ -59,7 +61,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { if (s > 1) { bar(); } fft_NW(u); tabMul(WIDTH / NW, trig, u, NW, s, me); - shufl( WIDTH / NW, lds, u, NW, s); + shufl(WIDTH / NW, lds, u, NW, s); } fft_NW(u); } @@ -67,15 +69,12 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { #endif - - - // New fft_WIDTH that uses more FMA instructions than the old fft_WIDTH. // The tabMul after fft8 only does a partial complex multiply, saving a mul-by-cosine for the next fft8 using FMA instructions. -// To maximize FMA opportunities we precompute tig values as cosine and sine/cosine rather than cosine and sine. +// To maximize FMA opportunities we precompute trig values as cosine and sine/cosine rather than cosine and sine. // The downside is sine/cosine cannot be computed with chained multiplies. -void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { +void OVERLOAD new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { u32 WG = WIDTH / NW; u32 me = get_local_id(0); @@ -209,6 +208,119 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { #endif } -void new_fft_WIDTH1(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 1); } -void new_fft_WIDTH2(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 2); } +// There are two version of new_fft_WIDTH in case we want to try saving some trig values from new_fft_WIDTH1 in LDS memory for later use in new_fft_WIDTH2. +void OVERLOAD new_fft_WIDTH1(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 1); } +void OVERLOAD new_fft_WIDTH2(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 2); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft_NW(F2 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local F2 *lds, F2 *u, TrigFP32 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local F2 *lds, F2 *u, TrigFP32 trig) { fft_WIDTH(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft_NW(GF31 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_WIDTH(lds, u, trig); } + +#endif + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft_NW(GF61 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_WIDTH(lds, u, trig); } + +#endif diff --git a/src/cl/math.cl b/src/cl/math.cl index a97ea506..d112a180 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -4,6 +4,44 @@ #include "base.cl" +// Access parts of a 64-bit value + +u32 lo32(u64 x) { return (u32) x; } +u32 hi32(u64 x) { return (u32) (x >> 32); } + +// A primitive partial implementation of an i96/u96 integer type +typedef union { + struct { u32 lo32; u32 mid32; u32 hi32; } a; + struct { u64 lo64; u32 hi32; } c; +} i96; +i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.a.hi32 = h, val.a.mid32 = m, val.a.lo32 = l; return val; } +i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.a.hi32 = hi32(h), val.a.mid32 = lo32(h), val.a.lo32 = l; return val; } +i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.c.hi32 = h, val.c.lo64 = l; return val; } +void i96_add(i96 *val, i96 x) { u64 lo64 = val->c.lo64 + x.c.lo64; val->c.hi32 += x.c.hi32 + (lo64 < val->c.lo64); val->c.lo64 = lo64; } +void OVERLOAD i96_sub(i96 *val, i96 x) { u64 lo64 = val->c.lo64 - x.c.lo64; val->c.hi32 -= x.c.hi32 + (lo64 > val->c.lo64); val->c.lo64 = lo64; } +void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->a.lo32 * x; val->a.lo32 = (u32)t; t = (u64)val->a.mid32 * x + (t >> 32); val->a.mid32 = (u32)t; val->a.hi32 = val->a.hi32 * x + (u32)(t >> 32); } +u32 i96_hi32(i96 val) { return val.c.hi32; } +u64 i96_lo64(i96 val) { return val.c.lo64; } +u64 i96_hi64(i96 val) { return ((u64) val.a.hi32 << 32) + val.a.mid32; } +u32 i96_lo32(i96 val) { return val.a.lo32; } + +// The X2 family of macros and SWAP are #defines because OpenCL does not allow pass by reference. +// With NTT support added, we need to turn these macros into overloaded routines. +#define X2(a, b) X2_internal(&(a), &(b)) // a = a + b, b = a - b +#define X2conjb(a, b) X2conjb_internal(&(a), &(b)) // X2(a, conjugate(b)) +#define X2_mul_t4(a, b) X2_mul_t4_internal(&(a), &(b)) // X2(a, b), b = mul_t4(b) +#define X2_mul_t8(a, b) X2_mul_t8_internal(&(a), &(b)) // X2(a, b), b = mul_t8(b) +#define X2_mul_3t8(a, b) X2_mul_3t8_internal(&(a), &(b)) // X2(a, b), b = mul_3t8(b) +#define X2_conja(a, b) X2_conja_internal(&(a), &(b)) // X2(a, b), a = conjugate(a) // NOT USED +#define X2_conjb(a, b) X2_conjb_internal(&(a), &(b)) // X2(a, b), b = conjugate(b) +#define SWAP(a, b) SWAP_internal(&(a), &(b)) // a = b, b = a +#define SWAP_XY(a) U2((a).y, (a).x) // Swap real and imaginary components of a + +#if FFT_FP64 + +T2 OVERLOAD conjugate(T2 a) { return U2(a.x, -a.y); } + // Multiply by 2 without using floating point instructions. This is a little sloppy as an input of zero returns 2^-1022. T OVERLOAD mul2(T a) { int2 tmp = as_int2(a); tmp.y += 0x00100000; /* Bump exponent by 1 */ return (as_double(tmp)); } T2 OVERLOAD mul2(T2 a) { return U2(mul2(a.x), mul2(a.y)); } @@ -16,37 +54,28 @@ T2 OVERLOAD mulminus2(T2 a) { return U2(mulminus2(a.x), mulminus2(a.y)); } T OVERLOAD fancyMul(T a, T b) { return fma(a, b, a); } T2 OVERLOAD fancyMul(T2 a, T2 b) { return U2(fancyMul(a.x, b.x), fancyMul(a.y, b.y)); } -T2 cmul(T2 a, T2 b) { -#if 1 - return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); -#else - return U2(fma(a.y, -b.y, a.x * b.x), fma(a.x, b.y, a.y * b.x)); -#endif -} +// Square a complex number +T2 OVERLOAD csq(T2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } +// a^2 + c +T2 OVERLOAD csqa(T2 a, T2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } +// Same as csq(a), -a +T2 OVERLOAD csq_neg(T2 a) { return U2(fma(-a.x, a.x, a.y * a.y), mulminus2(a.x) * a.y); } // NOT USED + +// Complex multiply +T2 OVERLOAD cmul(T2 a, T2 b) { return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); } -T2 conjugate(T2 a) { return U2(a.x, -a.y); } +T2 OVERLOAD cfma(T2 a, T2 b, T2 c) { return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); } -T2 cmul_by_conjugate(T2 a, T2 b) { return cmul(a, conjugate(b)); } +T2 OVERLOAD cmul_by_conjugate(T2 a, T2 b) { return cmul(a, conjugate(b)); } // Multiply a by b and conjugate(b). This saves 2 multiplies. -void cmul_a_by_b_and_conjb(T2 *res1, T2 *res2, T2 a, T2 b) { +void OVERLOAD cmul_a_by_b_and_conjb(T2 *res1, T2 *res2, T2 a, T2 b) { T axbx = a.x * b.x; T aybx = a.y * b.x; res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); } -T2 cfma(T2 a, T2 b, T2 c) { -#if 1 - return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); -#else - return U2(fma(a.y, -b.y, fma(a.x, b.x, c.x)), fma(a.x, b.y, fma(a.y, b.x, c.y))); -#endif -} - -// Square a complex number -T2 csq(T2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } - // Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. T2 csqTrig(T2 a) { T two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } T2 csqTrigFancy(T2 a) { T two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } @@ -58,9 +87,6 @@ T2 ccubeTrig(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), f T2 ccubeTrigFancy(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } T2 ccubeTrigDefancy(T2 sq, T2 w) { T tmp = mul2(sq.y); T wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } -// a^2 + c -T2 csqa(T2 a, T2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } - // Complex a * (b + 1) // Useful for mul with twiddles of small angles, where the real part is stored with the -1 trick for increased precision T2 cmulFancy(T2 a, T2 b) { return cfma(a, b, a); } @@ -73,10 +99,9 @@ void cmul_a_by_fancyb_and_conjfancyb(T2 *res1, T2 *res2, T2 a, T2 b) { res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); } +T2 OVERLOAD mul_t4(T2 a) { return U2(-a.y, a.x); } // i.e. a * i -T2 mul_t4(T2 a) { return U2(-a.y, a.x); } // i.e. a * i - -T2 mul_t8(T2 a) { // mul(a, U2( 1, 1)) * (T)(M_SQRT1_2); } +T2 OVERLOAD mul_t8(T2 a) { // mul(a, U2(1, 1)) * (T)(M_SQRT1_2); } // One mul, two FMAs T ay = a.y * M_SQRT1_2; return U2(fma(a.x, M_SQRT1_2, -ay), fma(a.x, M_SQRT1_2, ay)); @@ -84,7 +109,7 @@ T2 mul_t8(T2 a) { // mul(a, U2( 1, 1)) * (T)(M_SQRT1_2); } // return U2(a.x - a.y, a.x + a.y) * M_SQRT1_2; } -T2 mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } +T2 OVERLOAD mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } // One mul, two FMAs T ay = a.y * M_SQRT1_2; return U2(fma(-a.x, M_SQRT1_2, -ay), fma(a.x, M_SQRT1_2, -ay)); @@ -92,24 +117,667 @@ T2 mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } // return U2(-(a.x + a.y), a.x - a.y) * M_SQRT1_2; } -T2 swap(T2 a) { return U2(a.y, a.x); } -T2 addsub(T2 a) { return U2(a.x + a.y, a.x - a.y); } - -#define X2(a, b) { T2 t = a; a = t + b; b = t - b; } +// Return a+b and a-b +void OVERLOAD X2_internal(T2 *a, T2 *b) { T2 t = *a; *a = t + *b; *b = t - *b; } // Same as X2(a, b), b = mul_t4(b) -#define X2_mul_t4(a, b) { X2(a, b); b = mul_t4(b); } -// { T2 t = a; a = a + b; t.x = t.x - b.x; b.x = b.y - t.y; b.y = t.x; } +void OVERLOAD X2_mul_t4_internal(T2 *a, T2 *b) { T2 t = *a; *a = *a + *b; t.x = t.x - b->x; b->x = b->y - t.y; b->y = t.x; } // Same as X2(a, conjugate(b)) -#define X2conjb(a, b) { T2 t = a; a.x = a.x + b.x; a.y = a.y - b.y; b.x = t.x - b.x; b.y = t.y + b.y; } +void OVERLOAD X2conjb_internal(T2 *a, T2 *b) { T2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } // Same as X2(a, b), a = conjugate(a) -#define X2conja(a, b) { T2 t = a; (a).x = (a).x + (b).x; (a).y = -(a).y - (b).y; b = t - b; } +void OVERLOAD X2_conja_internal(T2 *a, T2 *b) { T2 t = *a; a->x = a->x + b->x; a->y = - (a->y + b->y); *b = t - *b; } -#define SWAP(a, b) { T2 t = a; a = b; b = t; } +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(T2 *a, T2 *b) { T2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } + +void OVERLOAD SWAP_internal(T2 *a, T2 *b) { T2 t = *a; *a = *b; *b = t; } T2 fmaT2(T a, T2 b, T2 c) { return fma(U2(a, a), b, c); } // a = c + sin * d; b = c - sin * d; #define fma_addsub(a, b, sin, c, d) { T2 t = c + sin * d; b = c - sin * d; a = t; } + +T2 OVERLOAD addsub(T2 a) { return U2(a.x + a.y, a.x - a.y); } + +// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) +// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 +T2 foo2(T2 a, T2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } + +// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) +T2 foo(T2 a) { return foo2(a, a); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 OVERLOAD conjugate(F2 a) { return U2(a.x, -a.y); } + +// Multiply by 2 without using floating point instructions. This is a little sloppy as an input of zero returns 2^-126. +F OVERLOAD mul2(F a) { return a + a; } //{ int tmp = as_int(a); tmp += 0x00800000; /* Bump exponent by 1 */ return (as_float(tmp)); } +F2 OVERLOAD mul2(F2 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Multiply by -2 without using floating point instructions. This is a little sloppy as an input of zero returns -2^-126. +F OVERLOAD mulminus2(F a) { return -2.0f * a; } //{ int tmp = as_int(a); tmp += 0x80800000; /* Bump exponent by 1, flip sign bit */ return (as_float(tmp)); } +F2 OVERLOAD mulminus2(F2 a) { return U2(mulminus2(a.x), mulminus2(a.y)); } + +// a * (b + 1) == a * b + a +F OVERLOAD fancyMul(F a, F b) { return fma(a, b, a); } +F2 OVERLOAD fancyMul(F2 a, F2 b) { return U2(fancyMul(a.x, b.x), fancyMul(a.y, b.y)); } + +// Square a complex number +F2 OVERLOAD csq(F2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } +// a^2 + c +F2 OVERLOAD csqa(F2 a, F2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } +// Same as csq(a), -a +F2 OVERLOAD csq_neg(F2 a) { return U2(fma(-a.x, a.x, a.y * a.y), mulminus2(a.x) * a.y); } // NOT USED + +// Complex multiply +F2 OVERLOAD cmul(F2 a, F2 b) { return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); } + +F2 OVERLOAD cfma(F2 a, F2 b, F2 c) { return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); } + +F2 OVERLOAD cmul_by_conjugate(F2 a, F2 b) { return cmul(a, conjugate(b)); } + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(F2 *res1, F2 *res2, F2 a, F2 b) { + F axbx = a.x * b.x; + F aybx = a.y * b.x; + res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); +} + +// Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. +F2 csqTrig(F2 a) { F two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } +F2 csqTrigFancy(F2 a) { F two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } +F2 csqTrigDefancy(F2 a) { F two_ay = mul2(a.y); return U2(fma (-two_ay, a.y, 1), fma(a.x, two_ay, two_ay)); } + +// Cube a complex number w (cos,sin) given w^2 and w. The squared input can be either fancy or not fancy. +// Fancy cCube takes a fancy w argument and returns a fancy value. Defancy takes a fancy w argument and returns a non-fancy value. +F2 ccubeTrig(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } +F2 ccubeTrigFancy(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } +F2 ccubeTrigDefancy(F2 sq, F2 w) { F tmp = mul2(sq.y); F wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } + +// Complex a * (b + 1) +// Useful for mul with twiddles of small angles, where the real part is stored with the -1 trick for increased precision +F2 cmulFancy(F2 a, F2 b) { return cfma(a, b, a); } + +// Multiply a by fancy b and conjugate(fancy b). This saves 2 FMAs. +void cmul_a_by_fancyb_and_conjfancyb(F2 *res1, F2 *res2, F2 a, F2 b) { + F axbx = fma(a.x, b.x, a.x); + F aybx = fma(a.y, b.x, a.y); + res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); +} + +F2 OVERLOAD mul_t4(F2 a) { return U2(-a.y, a.x); } // i.e. a * i + +F2 OVERLOAD mul_t8(F2 a) { // mul(a, U2(1, 1)) * (T)(M_SQRT1_2); } + // One mul, two FMAs + F ay = a.y * (float) M_SQRT1_2; + return U2(fma(a.x, (float) M_SQRT1_2, -ay), fma(a.x, (float) M_SQRT1_2, ay)); +// Two adds, two muls +// return U2(a.x - a.y, a.x + a.y) * M_SQRT1_2; +} + +F2 OVERLOAD mul_3t8(F2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } + // One mul, two FMAs + F ay = a.y * (float) M_SQRT1_2; + return U2(fma(-a.x, (float) M_SQRT1_2, -ay), fma(a.x, (float) M_SQRT1_2, -ay)); +// Two adds, two muls +// return U2(-(a.x + a.y), a.x - a.y) * M_SQRT1_2; +} + +// Return a+b and a-b +void OVERLOAD X2_internal(F2 *a, F2 *b) { F2 t = *a; *a = t + *b; *b = t - *b; } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(F2 *a, F2 *b) { F2 t = *a; *a = *a + *b; t.x = t.x - b->x; b->x = b->y - t.y; b->y = t.x; } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(F2 *a, F2 *b) { F2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } + +// Same as X2(a, b), a = conjugate(a) +void OVERLOAD X2_conja_internal(F2 *a, F2 *b) { F2 t = *a; a->x = a->x + b->x; a->y = - (a->y + b->y); *b = t - *b; } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(F2 *a, F2 *b) { F2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } + +void OVERLOAD SWAP_internal(F2 *a, F2 *b) { F2 t = *a; *a = *b; *b = t; } + +F2 fmaT2(F a, F2 b, F2 c) { return fma(U2(a, a), b, c); } + +// a = c + sin * d; b = c - sin * d; +#define fma_addsub(a, b, sin, c, d) { F2 t = c + sin * d; b = c - sin * d; a = t; } + +F2 OVERLOAD addsub(F2 a) { return U2(a.x + a.y, a.x - a.y); } + +// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) +// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 +F2 foo2(F2 a, F2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } + +// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) +F2 foo(F2 a) { return foo2(a, a); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// bits in reduced mod M. +#define M31 ((((Z31) 1) << 31) - 1) + +#if 0 // Version that keeps results strictly in the 0..M31-1 range + +Z31 OVERLOAD mod(Z31 a) { return (a & M31) + (a >> 31); } // GWBUG: This could be larger than M31 (unless a is result of an add), need a wesk and strong mod + +Z31 OVERLOAD add(Z31 a, Z31 b) { Z31 t = a + b; return t - (t >= M31 ? M31 : 0); } //GWBUG - an if stmt may be faster +GF31 OVERLOAD add(GF31 a, GF31 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z31 OVERLOAD sub(Z31 a, Z31 b) { Z31 t = a - b; return t + (t >= M31 ? M31 : 0); } //GWBUG - an if stmt may be faster. So might "(i64) t < 0". +GF31 OVERLOAD sub(GF31 a, GF31 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +Z31 OVERLOAD neg(Z31 a) { return a == 0 ? 0 : M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF31 OVERLOAD neg(GF31 a) { return U2(neg(a.x), neg(a.y)); } + +Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a +Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 +Z31 OVERLOAD make_Z31(i64 a) { if (a < 0) a += (((i64) M31 << 31) + M31); return add((Z31) (a & M31), (Z31) (a >> 31)); } // Handles 62-bit a values + +u32 get_Z31(Z31 a) { return a; } // Get balanced value in range 0 to M31-1 +i32 get_balanced_Z31(Z31 a) { return (a & 0xC0000000) ? (i32) a - M31 : (i32) a; } // Get balanced value in range -M31/2 to M31/2 + +// Assumes k reduced mod 31. +Z31 OVERLOAD shl(Z31 a, u32 k) { return ((a << k) + (a >> (31 - k))) & M31; } +GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } +Z31 OVERLOAD shr(Z31 a, u32 k) { return ((a >> k) + (a << (31 - k))) & M31; } +GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } + +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } + +Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z31 OVERLOAD mul2(Z31 a) { return ((a + a) + (a >> 30)) & M31; } // GWBUG: Can we do better? +GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF31 OVERLOAD csq(GF31 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs + +// a^2 + c +GF31 OVERLOAD csqa(GF31 a, GF31 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? Find a way to use fma instructions + +// Complex mul +//GF31 OVERLOAD cmul(GF31 a, GF31 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +GF31 OVERLOAD cmul(GF31 a, GF31 b) { + Z31 k1 = mul(b.x, add(a.x, a.y)); + Z31 k2 = mul(a.x, sub(b.y, b.x)); + Z31 k3 = mul(a.y, add(b.y, b.x)); + return U2(sub(k1, k3), add(k1, k2)); +} + +GF31 OVERLOAD cfma(GF31 a, GF31 b, GF31 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? + +GF31 OVERLOAD cmul_by_conjugate(GF31 a, GF31 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(GF31 *res1, GF31 *res2, GF31 a, GF31 b) { + Z31 axbx = mul(a.x, b.x); + Z31 aybx = mul(a.y, b.x); + res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^15, 2^15). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^15)^2 == 1 (mod M31). +GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } // GWBUG: Can caller use a version that does not negate real? is shl(neg) same as shr??? + +// mul with (-2^15, 2^15). (twiddle of 3*tau/8). +GF31 OVERLOAD mul_3t8(GF31 a) { return U2(shl(neg(add(a.x, a.y)), 15), shl(sub(a.x, a.y), 15)); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? + +// Same as X2(a, b), a = conjugate(a) +void OVERLOAD X2_conja_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = *b; *b = t; } + +GF31 OVERLOAD addsub(GF31 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF31 OVERLOAD foo2(GF31 a, GF31 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } + + + + + +#elif 1 // This version is a little sloppy. Returns values in 0..M31 range //GWBUG (could this handle M31+1 too> neg() is hard. If so made_Z31(i64) is faster + +// Internal routine to return value in 0..M31 range +Z31 OVERLOAD mod(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return M31+1 + +Z31 OVERLOAD neg(Z31 a) { return M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF31 OVERLOAD neg(GF31 a) { return U2(neg(a.x), neg(a.y)); } + +Z31 OVERLOAD add(Z31 a, Z31 b) { return mod(a + b); } +GF31 OVERLOAD add(GF31 a, GF31 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z31 OVERLOAD sub(Z31 a, Z31 b) { return mod(a + neg(b)); } +GF31 OVERLOAD sub(GF31 a, GF31 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a +Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 +Z31 OVERLOAD make_Z31(i64 a) { if (a < 0) a += (((i64) M31 << 31) + M31); return add((Z31) (a & M31), (Z31) (a >> 31)); } // Handles 62-bit a values + +u32 get_Z31(Z31 a) { return a == M31 ? 0 : a; } // Get value in range 0 to M31-1 +i32 get_balanced_Z31(Z31 a) { return (a & 0xC0000000) ? (i32) a - M31 : (i32) a; } // Get balanced value in range -M31/2 to M31/2 + +// Assumes k reduced mod 31. +Z31 OVERLOAD shl(Z31 a, u32 k) { return ((a << k) + (a >> (31 - k))) & M31; } +GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } +Z31 OVERLOAD shr(Z31 a, u32 k) { return ((a >> k) + (a << (31 - k))) & M31; } +GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } + +//Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } //GWBUG. is M31 * M31 a problem???? I think so! needs double mod +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return mod(add((Z31) (t & M31), (Z31) (t >> 31))); } //Fixes the M31 * M31 problem + +Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z31 OVERLOAD mul2(Z31 a) { return add(a, a); } +GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF31 OVERLOAD csq(GF31 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs + +// a^2 + c +GF31 OVERLOAD csqa(GF31 a, GF31 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? Find a way to use fma instructions + +// Complex mul +//GF31 OVERLOAD cmul(GF31 a, GF31 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +GF31 OVERLOAD cmul(GF31 a, GF31 b) { + Z31 k1 = mul(b.x, add(a.x, a.y)); + Z31 k2 = mul(a.x, sub(b.y, b.x)); + Z31 k3 = mul(a.y, add(b.y, b.x)); + return U2(sub(k1, k3), add(k1, k2)); +} + +GF31 OVERLOAD cfma(GF31 a, GF31 b, GF31 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? + +GF31 OVERLOAD cmul_by_conjugate(GF31 a, GF31 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(GF31 *res1, GF31 *res2, GF31 a, GF31 b) { + Z31 axbx = mul(a.x, b.x); + Z31 aybx = mul(a.y, b.x); + res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^15, 2^15). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^15)^2 == 1 (mod M31). +GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } // GWBUG: Can caller use a version that does not negate real? is shl(neg) same as shr??? + +// mul with (-2^15, 2^15). (twiddle of 3*tau/8). +GF31 OVERLOAD mul_3t8(GF31 a) { return U2(shl(neg(add(a.x, a.y)), 15), shl(sub(a.x, a.y), 15)); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? + +// Same as X2(a, b), a = conjugate(a) +void OVERLOAD X2_conja_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = *b; *b = t; } + +GF31 OVERLOAD addsub(GF31 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF31 OVERLOAD foo2(GF31 a, GF31 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// bits in reduced mod M. +#define M61 ((((Z61) 1) << 61) - 1) + +Z61 OVERLOAD make_Z61(i32 a) { return (Z61) (a < 0 ? (i64) a + M61 : (i64) a); } // Handles all values of a +Z61 OVERLOAD make_Z61(i64 a) { return (Z61) (a < 0 ? a + M61 : a); } // a must be in range of -M61 .. M61-1 +Z61 OVERLOAD make_Z61(u32 a) { return (Z61) (a); } // Handles all values of a +Z61 OVERLOAD make_Z61(u64 a) { return (Z61) (a); } // a must be in range of 0 .. M61-1 + +#if 0 // Slower version that keeps results strictly in the range 0 .. M61-1 + +u64 OVERLOAD get_Z61(Z61 a) { return a; } // Get value in range 0 to M61-1 +i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 + +Z61 OVERLOAD neg(Z61 a) { return a == 0 ? 0 : M61 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } + +Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; Z61 m = t - M61; return (m & 0x8000000000000000ULL) ? t : m; } +//Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; Z61 m = t - M61; return t < m ? t : m; } // Slower on TitanV +//Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; return t - (t >= M61 ? M61 : 0); } // Slower on TitanV +GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; return t + (((i64) t >> 63) & 0x1FFFFFFFFFFFFFFFULL); } +//Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; Z61 p = t + M61; return (t & 0x8000000000000000ULL) ? p : t; } // Better??? +//Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; return t + (t >= M61 ? M61 : 0); } // Slower on TitanV +// BETTER???: t = a - b; carry_mask = sbb x, x; (generates 32 bits of 0 or 1; return t + make_carry_mask_64_bits +GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +// Assumes k reduced mod 61. +Z61 OVERLOAD shl(Z61 a, u32 k) { return ((a << k) + (a >> (61 - k))) & M61; } //GWBUG: Make sure & M61 operates on just one u32 +GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } +Z61 OVERLOAD shr(Z61 a, u32 k) { return ((a >> k) + (a << (61 - k))) & M61; } //GWBUG: Make sure & M61 operates on just one u32. & M61 not needed? +GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } + +ulong2 wideMul(u64 ab, u64 cd) { + u128 r = (u128) ab * (u128) cd; + return U2((u64) r, (u64) (r >> 64)); +} + +Z61 OVERLOAD mul(Z61 a, Z61 b) { + ulong2 ab = wideMul(a, b); + u64 lo = ab.x, hi = ab.y; + u64 lo61 = lo & M61, hi61 = (hi << 3) + (lo >> 61); + return add(lo61, hi61); +} + +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z61 OVERLOAD mul2(Z61 a) { return ((a + a) + (a >> 60)) & M61; } // GWBUG: Make sure "+ a>>60" does an add to lower u32 without a followup adc. +GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 61 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF61 OVERLOAD csq(GF61 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs + +// a^2 + c +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? Find a way to use fma instructions + +// Complex mul +//GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +GF61 OVERLOAD cmul(GF61 a, GF61 b) { + Z61 k1 = mul(b.x, add(a.x, a.y)); + Z61 k2 = mul(a.x, sub(b.y, b.x)); + Z61 k3 = mul(a.y, add(b.y, b.x)); + return U2(sub(k1, k3), add(k1, k2)); +} + +GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? + +GF61 OVERLOAD cmul_by_conjugate(GF61 a, GF61 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(GF61 *res1, GF61 *res2, GF61 a, GF61 b) { + Z61 axbx = mul(a.x, b.x); + Z61 aybx = mul(a.y, b.x); + res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (-2^30, -2^30). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^30)^2 == 1 (mod M61). +GF61 OVERLOAD mul_t8(GF61 a) { return shl(U2(sub(a.y, a.x), neg(add(a.x, a.y))), 30); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^30, -2^30). (twiddle of 3*tau/8). +GF61 OVERLOAD mul_3t8(GF61 a) { return shl(U2(add(a.x, a.y), sub(a.y, a.x)), 30); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_3t8(*b); } + +// Same as X2(a, b), a = conjugate(a) +void OVERLOAD X2_conja_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = *b; *b = t; } + +GF61 OVERLOAD addsub(GF61 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF61 OVERLOAD foo2(GF61 a, GF61 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } + +// The following routines can be used to reduce mod M61 operations (in the other Z61 implementations). +// Caller must track how many M61s need to be added to make positive values for subtractions. +// In function names, "q" stands for quick, "s" stands for slow (i.e. does mod). +// These functions are untested with this strict Z61 implementation. Callers need to eliminate all uses of + or - operators. + +Z61 OVERLOAD mod(Z61 a) { return a; } +GF61 OVERLOAD mod(GF61 a) { return a; } +Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return neg(a); } +GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return neg(a); } +Z61 OVERLOAD addq(Z61 a, Z61 b) { return add(a, b); } +GF61 OVERLOAD addq(GF61 a, GF61 b) { return add(a, b); } +Z61 OVERLOAD subq(Z61 a, Z61 b, u32 m61_count) { return sub(a, b); } +GF61 OVERLOAD subq(GF61 a, GF61 b, u32 m61_count) { return sub(a, b); } +Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return sub(a, b); } +GF61 OVERLOAD subs(GF61 a, GF61 b, u32 m61_count) { return sub(a, b); } +void OVERLOAD X2q(GF61 *a, GF61 *b, u32 m61_count) { X2_internal(a, b); } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { X2_mul_t4_internal(a, b); } +void OVERLOAD X2s(GF61 *a, GF61 *b, u32 m61_count) { X2_internal(a, b); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { X2_conjb_internal(a, b); } + + + + +// Philosophy: This Z61/GF61 implementation uses faster, sloppier mod M61 reduction where the end result is in the range 0..M61+epsilon. +// This implementation also handles subtractions by adding enough M61s to make a value positive. This allows us to always deal with positive +// intermediate results. The downside is that a caller using the sloppy/quick routines must keep track of how large unreduced values can get. +// An alternative implementation is to have Z61 be an i64 (costs us a precious bit of precision) but is surprisingly slower (at least on TitanV) because +// mod(a - b), where the mod routinue uses a signed right shift is slower than +// mod(a + (M61*2 - b)) where the mod routine uses an unsigned shift right. +// However, a long string of subtracts (example, fft8 does 3 subtracts before mod M61 might be better off using negative intermediate results. +// The mul routine (and obviously csq and cmul) must use only positive values as __int128 multiply is very slow. + +#elif 1 // Faster version that keeps results in the range 0 .. M61+epsilon + +u64 OVERLOAD get_Z61(Z61 a) { Z61 m = a - M61; return (m & 0x8000000000000000ULL) ? a : m; } // Get value in range 0 to M61-1 +i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 + +// Internal routine to bring Z61 value into the range 0..M61+epsilon +Z61 OVERLOAD mod(Z61 a) { return (a & M61) + (a >> 61); } +GF61 OVERLOAD mod(GF61 a) { return U2(mod(a.x), mod(a.y)); } +// Internal routine to negate a value by adding the specified number of M61s -- no mod M61 reduction +Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return m61_count * M61 - a; } +GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return U2(neg(a.x, m61_count), neg(a.y, m61_count)); } + +Z61 OVERLOAD add(Z61 a, Z61 b) { return mod(a + b); } +GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z61 OVERLOAD sub(Z61 a, Z61 b) { return mod(a + neg(b, 2)); } +GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + + Z61 OVERLOAD neg(Z61 a) { return mod (neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided +GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } + +// Assumes k reduced mod 61. +Z61 OVERLOAD shr(Z61 a, u32 k) { return (a >> k) + ((a << (61 - k)) & M61); } // Return range 0..M61+2^(61-k), can handle 64-bit inputs but small k is big epsilon +GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } +Z61 OVERLOAD shl(Z61 a, u32 k) { return shr(a, 61 - k); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k yields big epsilon +//Z61 OVERLOAD shl(Z61 a, u32 k) { return mod(a << k) + ((a >> (64 - k)) << 3); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k is big epsilon +//Z61 OVERLOAD shl(Z61 a, u32 k) { return mod((a << k) + ((a >> (64 - k)) << 3)); } // Return range 0..M61+epsilon, input must be M61+epsilon a full 62-bit value can overflow +GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } + +ulong2 wideMul(u64 ab, u64 cd) { + u128 r = (u128) ab * (u128) cd; + return U2((u64) r, (u64) (r >> 64)); +} + +Z61 OVERLOAD weakMul(Z61 a, Z61 b) { // a*b must fit in 125 bits, result will as large as a*b >> 61 + ulong2 ab = wideMul(a, b); + u64 lo = ab.x, hi = ab.y; + u64 lo61 = lo & M61, hi61 = (hi << 3) + (lo >> 61); + return lo61 + hi61; +} + +Z61 OVERLOAD mul(Z61 a, Z61 b) { return mod(weakMul(a, b)); } + +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return mod(weakMul(a, b) + c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z61 OVERLOAD mul2(Z61 a) { return add(a, a); } +GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } + +// Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, mod(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } + +// a^2 + c +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(mod(weakMul(a.x + a.y, mod(a.x + neg(a.y, 2))) + c.x), mod(weakMul(a.x + a.x, a.y) + c.y)); } + +// Complex mul +//GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} +GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 + Z61 k1 = weakMul(b.x, a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result + Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result + Z61 k3 = weakMul(neg(a.y, 2), b.y + b.x); // 62 * 62+e bits = 63+e bit result + return U2(mod(k1 + k3), mod(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values +} + +GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? + +GF61 OVERLOAD cmul_by_conjugate(GF61 a, GF61 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(GF61 *res1, GF61 *res2, GF61 a, GF61 b) { + Z61 axbx = mul(a.x, b.x); + Z61 aybx = mul(a.y, b.x); + res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (-2^30, -2^30). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^30)^2 == 1 (mod M61). +GF61 OVERLOAD mul_t8(GF61 a, u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), neg(a.x + a.y, 2 * m61_count - 1)), 30); } +GF61 OVERLOAD mul_t8(GF61 a) { return mul_t8(a, 2); } + +// mul with (2^30, -2^30). (twiddle of 3*tau/8). +GF61 OVERLOAD mul_3t8(GF61 a, u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } +GF61 OVERLOAD mul_3t8(GF61 a) { return mul_3t8(a, 2); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = t + neg(*b, 2); *b = mul_t8(*b, 4); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = t + neg(*b, 2); *b = mul_3t8(*b, 4); } + +// Same as X2(a, b), a = conjugate(a) +void OVERLOAD X2_conja_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = *b; *b = t; } + +GF61 OVERLOAD addsub(GF61 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF61 OVERLOAD foo2(GF61 a, GF61 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } + +// The following routines can be used to reduce mod M61 operations. Caller must track how many M61s need to be added to make positive +// values for subtractions. In function names, "q" stands for quick, "s" stands for slow (i.e. does mod). + +Z61 OVERLOAD addq(Z61 a, Z61 b) { return a + b; } +GF61 OVERLOAD addq(GF61 a, GF61 b) { return U2(addq(a.x, b.x), addq(a.y, b.y)); } + +Z61 OVERLOAD subq(Z61 a, Z61 b, u32 m61_count) { return a + neg(b, m61_count); } +GF61 OVERLOAD subq(GF61 a, GF61 b, u32 m61_count) { return U2(subq(a.x, b.x, m61_count), subq(a.y, b.y, m61_count)); } + +Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return mod(a + neg(b, m61_count)); } +GF61 OVERLOAD subs(GF61 a, GF61 b, u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } + +void OVERLOAD X2q(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } + +void OVERLOAD X2s(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, m61_count); b->y = subs(b->y, t.y, m61_count); } + +#endif + +#endif diff --git a/src/cl/middle.cl b/src/cl/middle.cl index f1a40b75..ac649881 100644 --- a/src/cl/middle.cl +++ b/src/cl/middle.cl @@ -34,6 +34,8 @@ #define MIDDLE_OUT_LDS_TRANSPOSE 1 #endif +#if FFT_FP64 || NTT_GF61 + //**************************************************************************************** // Pair of routines to write data from carryFused and read data into fftMiddleIn //**************************************************************************************** @@ -50,7 +52,7 @@ // u[i] i ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) // y ranges 0...SMALL_HEIGHT-1 (multiples of one) -void writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { +void OVERLOAD writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { #if PAD_SIZE > 0 u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; out += line * WIDTH + line * PAD_SIZE + line / SMALL_HEIGHT * BIG_PAD_SIZE + (u32) get_local_id(0); // One pad every line + a big pad every SMALL_HEIGHT lines @@ -61,7 +63,7 @@ void writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { #endif } -void readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { +void OVERLOAD readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { #if PAD_SIZE > 0 // Each work group reads successive y's which increments by one pad size. // Rather than having u[i] also increment by one, we choose a larger pad increment @@ -86,7 +88,7 @@ void readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { // x ranges 0...SMALL_HEIGHT-1 (multiples of one) (also known as 0...G_H-1 and 0...NH-1) // y ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) -void writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) +void OVERLOAD writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) { //u32 SIZEY = IN_WG / IN_SIZEX; //u32 num_x_chunks = WIDTH / IN_SIZEX; // Number of x chunks @@ -121,7 +123,7 @@ void writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) // Read a line for tailFused or fftHin // This reads partially transposed data as written by fftMiddleIn -void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { +void OVERLOAD readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { u32 SIZEY = IN_WG / IN_SIZEX; #if PAD_SIZE > 0 @@ -145,8 +147,8 @@ void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value for (i32 i = 0; i < NH; ++i) { -// u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value -// u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value +// u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value +// u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleInLine did chunk_y += chunk_y_incr; } @@ -196,7 +198,7 @@ void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { // i in u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) // y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) (processed in batches of OUT_WG/OUT_SIZEX) -void writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { +void OVERLOAD writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { #if PAD_SIZE > 0 #if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; @@ -211,7 +213,7 @@ void writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { #endif } -void readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { +void OVERLOAD readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { #if PAD_SIZE > 0 #if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 // Each u[i] increments by one pad size. @@ -274,7 +276,7 @@ void readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { // adjusted to effect a transpose. Or caller must transpose the x and y values and send us an out pointer with thread_id added in. // In other words, caller is responsible for deciding the best way to transpose x and y values. -void writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) +void OVERLOAD writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) { //u32 SIZEY = OUT_WG / OUT_SIZEX; //u32 num_x_chunks = SMALL_HEIGHT / OUT_SIZEX; // Number of x chunks @@ -307,7 +309,7 @@ void writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) } // Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. -void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { +void OVERLOAD readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { u32 me = get_local_id(0); u32 SIZEY = OUT_WG / OUT_SIZEX; @@ -332,8 +334,8 @@ void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value for (i32 i = 0; i < NW; ++i) { -// u32 fftMiddleOut_y = i * G_W + me; // The fftMiddleOut y value -// u32 chunk_y = fftMiddleOut_y / SIZEY; // The fftMiddleOut chunk_y value +// u32 fftMiddleOut_y = i * G_W + me; // The fftMiddleOut y value +// u32 chunk_y = fftMiddleOut_y / SIZEY; // The fftMiddleOut chunk_y value u[i] = NTLOAD(in[chunk_y * (MIDDLE * OUT_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleOutLine did chunk_y += chunk_y_incr; } @@ -367,3 +369,303 @@ void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { #endif } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 || NTT_GF31 + +void OVERLOAD writeCarryFusedLine(F2 *u, P(F2) out, u32 line) { +#if PAD_SIZE > 0 + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + out += line * WIDTH + line * PAD_SIZE + line / SMALL_HEIGHT * BIG_PAD_SIZE + (u32) get_local_id(0); // One pad every line + a big pad every SMALL_HEIGHT lines + for (u32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W], u[i]); } +#else + out += line * WIDTH + (u32) get_local_id(0); + for (u32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W], u[i]); } +#endif +} + +void OVERLOAD readMiddleInLine(F2 *u, CP(F2) in, u32 y, u32 x) { +#if PAD_SIZE > 0 + // Each work group reads successive y's which increments by one pad size. + // Rather than having u[i] also increment by one, we choose a larger pad increment + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + in += y * WIDTH + y * PAD_SIZE + (y / SMALL_HEIGHT) * BIG_PAD_SIZE + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * (SMALL_HEIGHT * (WIDTH + PAD_SIZE) + BIG_PAD_SIZE)]); } +#else + in += y * WIDTH + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SMALL_HEIGHT * WIDTH]); } +#endif +} + +void OVERLOAD writeMiddleInLine (P(F2) out, F2 *u, u32 chunk_y, u32 chunk_x) +{ +#if PAD_SIZE > 0 + u32 SIZEY = IN_WG / IN_SIZEX; + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + + out += chunk_y * (MIDDLE * IN_WG + PAD_SIZE) + // Write y chunks after middle chunks and a pad + chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX + // num_y_chunks * (MIDDLE * IN_WG + PAD_SIZE) + SMALL_HEIGHT / SIZEY * PAD_SIZE + BIG_PAD_SIZE); + // = SMALL_HEIGHT / SIZEY * (MIDDLE * IN_WG + PAD_SIZE) + // = SMALL_HEIGHT / (IN_WG / IN_SIZEX) * (MIDDLE * IN_WG + PAD_SIZE) + // = SMALL_HEIGHT * MIDDLE * IN_SIZEX + SMALL_HEIGHT / SIZEY * PAD_SIZE + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * IN_WG], u[i]); } +#else + // Output data such that readCarryFused lines are packed tightly together. No padding. + out += chunk_y * MIDDLE * IN_WG + // Write y chunks after middles + chunk_x * MIDDLE * SMALL_HEIGHT * IN_SIZEX; // num_y_chunks * IN_WG = SMALL_HEIGHT / SIZEY * MIDDLE * IN_WG + // = MIDDLE * SMALL_HEIGHT / (IN_WG / IN_SIZEX) * IN_WG + // = MIDDLE * SMALL_HEIGHT * IN_SIZEX + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * IN_WG], u[i]); } +#endif +} + +// Read a line for tailFused or fftHin +// This reads partially transposed data as written by fftMiddleIn +void OVERLOAD readTailFusedLine(CP(F2) in, F2 *u, u32 line, u32 me) { + u32 SIZEY = IN_WG / IN_SIZEX; +#if PAD_SIZE > 0 + // Adjust in pointer based on the x value used in writeMiddleInLine + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + u32 fftMiddleIn_x = line % WIDTH; // The fftMiddleIn x value + u32 chunk_x = fftMiddleIn_x / IN_SIZEX; // The fftMiddleIn chunk_x value + in += chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX + SMALL_HEIGHT / SIZEY * PAD_SIZE + BIG_PAD_SIZE); // Adjust in pointer the same way writeMiddleInLine did + u32 x_within_in_wg = fftMiddleIn_x % IN_SIZEX; // There were IN_SIZEX x values within IN_WG + in += x_within_in_wg * SIZEY; // Adjust in pointer the same way writeMiddleInLine wrote x values within IN_WG + + // Adjust in pointer based on the i value used in writeMiddleInLine + u32 fftMiddleIn_i = line / WIDTH; // The i in fftMiddleIn's u[i] + in += fftMiddleIn_i * IN_WG; // Adjust in pointer the same way writeMiddleInLine did + // Adjust in pointer based on the y value used in writeMiddleInLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleIn_y = me; // The i=0 fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The i=0 fftMiddleIn chunk_y value + u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value + u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value + for (i32 i = 0; i < NH; ++i) { + u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleInLine did + chunk_y += chunk_y_incr; + } +#else // Read data that was not rotated or padded + // Adjust in pointer based on the x value used in writeMiddleInLine + u32 fftMiddleIn_x = line % WIDTH; // The fftMiddleIn x value + u32 chunk_x = fftMiddleIn_x / IN_SIZEX; // The fftMiddleIn chunk_x value + in += chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX); // Adjust in pointer the same way writeMiddleInLine did + u32 x_within_in_wg = fftMiddleIn_x % IN_SIZEX; // There were IN_SIZEX x values within IN_WG + in += x_within_in_wg * SIZEY; // Adjust in pointer the same way writeMiddleInLine wrote x values within IN_WG + // Adjust in pointer based on the i value used in writeMiddleInLine + u32 fftMiddleIn_i = line / WIDTH; // The i in fftMiddleIn's u[i] + in += fftMiddleIn_i * IN_WG; // Adjust in pointer the same way writeMiddleInLine did + // Adjust in pointer based on the y value used in writeMiddleInLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleIn_y = me; // The i=0 fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The i=0 fftMiddleIn chunk_y value + u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value + u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value + for (i32 i = 0; i < NH; ++i) { + u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value + u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG)]); // Adjust in pointer the same way writeMiddleInLine did + chunk_y += chunk_y_incr; + } +#endif +} + +void OVERLOAD writeTailFusedLine(F2 *u, P(F2) out, u32 line, u32 me) { +#if PAD_SIZE > 0 +#if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + out += line * (SMALL_HEIGHT + PAD_SIZE) + line / MIDDLE * BIG_PAD_SIZE + me; // Pad every output line plus every MIDDLE +#else + out += line * (SMALL_HEIGHT + PAD_SIZE) + me; // Pad every output line +#endif + for (u32 i = 0; i < NH; ++i) { NTSTORE(out[i * G_H], u[i]); } +#else // No padding + out += line * SMALL_HEIGHT + me; + for (u32 i = 0; i < NH; ++i) { NTSTORE(out[i * G_H], u[i]); } +#endif +} + +void OVERLOAD readMiddleOutLine(F2 *u, CP(F2) in, u32 y, u32 x) { +#if PAD_SIZE > 0 +#if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 + // Each u[i] increments by one pad size. + // Rather than each work group reading successive y's also increment by one, we choose a larger pad increment. + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + in += y * MIDDLE * (SMALL_HEIGHT + PAD_SIZE) + y * BIG_PAD_SIZE + x; +#else + in += y * MIDDLE * (SMALL_HEIGHT + PAD_SIZE) + x; +#endif + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * (SMALL_HEIGHT + PAD_SIZE)]); } +#else // No rotation, might be better on nVidia cards + in += y * MIDDLE * SMALL_HEIGHT + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SMALL_HEIGHT]); } +#endif +} + +void OVERLOAD writeMiddleOutLine (P(F2) out, F2 *u, u32 chunk_y, u32 chunk_x) +{ +#if PAD_SIZE > 0 + u32 SIZEY = OUT_WG / OUT_SIZEX; + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + + out += chunk_y * (MIDDLE * OUT_WG + PAD_SIZE) + // Write y chunks after middle chunks and a pad + chunk_x * (WIDTH * MIDDLE * OUT_SIZEX + // num_y_chunks * (MIDDLE * OUT_WG + PAD_SIZE) + WIDTH / SIZEY * PAD_SIZE + BIG_PAD_SIZE);// = WIDTH / SIZEY * (MIDDLE * OUT_WG + PAD_SIZE) + // = WIDTH / (OUT_WG / OUT_SIZEX) * (MIDDLE * OUT_WG + PAD_SIZE) + // = WIDTH * MIDDLE * OUT_SIZEX + WIDTH / SIZEY * PAD_SIZE + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * OUT_WG], u[i]); } +#else + // Output data such that readCarryFused lines are packed tightly together. No padding. + out += chunk_y * MIDDLE * OUT_WG + // Write y chunks after middles + chunk_x * MIDDLE * WIDTH * OUT_SIZEX; // num_y_chunks * OUT_WG = WIDTH / SIZEY * MIDDLE * OUT_WG + // = MIDDLE * WIDTH / (OUT_WG / OUT_SIZEX) * OUT_WG + // = MIDDLE * WIDTH * OUT_SIZEX + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * OUT_WG], u[i]); } +#endif +} + +void OVERLOAD readCarryFusedLine(CP(F2) in, F2 *u, u32 line) { + u32 me = get_local_id(0); + u32 SIZEY = OUT_WG / OUT_SIZEX; +#if PAD_SIZE > 0 + // Adjust in pointer based on the x value used in writeMiddleOutLine + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + u32 fftMiddleOut_x = line % SMALL_HEIGHT; // The fftMiddleOut x value + u32 chunk_x = fftMiddleOut_x / OUT_SIZEX; // The fftMiddleOut chunk_x value + in += chunk_x * (WIDTH * MIDDLE * OUT_SIZEX + WIDTH / SIZEY * PAD_SIZE + BIG_PAD_SIZE); // Adjust in pointer the same way writeMiddleOutLine did + u32 x_within_out_wg = fftMiddleOut_x % OUT_SIZEX; // There were OUT_SIZEX x values within OUT_WG + in += x_within_out_wg * SIZEY; // Adjust in pointer the same way writeMiddleOutLine wrote x values within OUT_WG + // Adjust in pointer based on the i value used in writeMiddleOutLine + u32 fftMiddleOut_i = line / SMALL_HEIGHT; // The i in fftMiddleOut's u[i] + in += fftMiddleOut_i * OUT_WG; // Adjust in pointer the same way writeMiddleOutLine did + // Adjust in pointer based on the y value used in writeMiddleOutLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleOut_y = me; // The i=0 fftMiddleOut y value + u32 chunk_y = fftMiddleOut_y / SIZEY; // The i=0 fftMiddleOut chunk_y value + u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value + u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value + for (i32 i = 0; i < NW; ++i) { + u[i] = NTLOAD(in[chunk_y * (MIDDLE * OUT_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleOutLine did + chunk_y += chunk_y_incr; + } +#else // Read data that was not rotated or padded + // Adjust in pointer based on the x value used in writeMiddleOutLine + u32 fftMiddleOut_x = line % SMALL_HEIGHT; // The fftMiddleOut x value + u32 chunk_x = fftMiddleOut_x / OUT_SIZEX; // The fftMiddleOut chunk_x value + in += chunk_x * MIDDLE * WIDTH * OUT_SIZEX; // Adjust in pointer the same way writeMiddleOutLine did + u32 x_within_out_wg = fftMiddleOut_x % OUT_SIZEX; // There were OUT_SIZEX x values within OUT_WG + in += x_within_out_wg * SIZEY; // Adjust in pointer the same way writeMiddleOutLine wrote x values with OUT_WG + // Adjust in pointer based on the i value used in writeMiddleOutLine + u32 fftMiddleOut_i = line / SMALL_HEIGHT; // The i in fftMiddleOut's u[i] + in += fftMiddleOut_i * OUT_WG; // Adjust in pointer the same way writeMiddleOutLine did + // Adjust in pointer based on the y value used in writeMiddleOutLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleOut_y = me; // The i=0 fftMiddleOut y value + u32 chunk_y = fftMiddleOut_y / SIZEY; // The i=0 fftMiddleOut chunk_y value + u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value + u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value + for (i32 i = 0; i < NW; ++i) { + u[i] = NTLOAD(in[chunk_y * MIDDLE * OUT_WG]); // Adjust in pointer the same way writeMiddleOutLine did + chunk_y += chunk_y_incr; + } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Since F2 and GF31 are the same size we can simply call the floats based code + +void OVERLOAD writeCarryFusedLine(GF31 *u, P(GF31) out, u32 line) { + writeCarryFusedLine((F2 *) u, (P(F2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleInLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF31) out, GF31 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleInLine ((P(F2)) out, (F2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readTailFusedLine(CP(GF31) in, GF31 *u, u32 line, u32 me) { + readTailFusedLine((CP(F2)) in, (F2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF31 *u, P(GF31) out, u32 line, u32 me) { + writeTailFusedLine((F2 *) u, (P(F2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleOutLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF31) out, GF31 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleOutLine ((P(F2)) out, (F2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readCarryFusedLine(CP(GF31) in, GF31 *u, u32 line) { + readCarryFusedLine((CP(F2)) in, (F2 *) u, line); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Since T2 and GF61 are the same size we can simply call the doubles based code + +void OVERLOAD writeCarryFusedLine(GF61 *u, P(GF61) out, u32 line) { + writeCarryFusedLine((T2 *) u, (P(T2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleInLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF61) out, GF61 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleInLine ((P(T2)) out, (T2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readTailFusedLine(CP(GF61) in, GF61 *u, u32 line, u32 me) { + readTailFusedLine((CP(T2)) in, (T2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF61 *u, P(GF61) out, u32 line, u32 me) { + writeTailFusedLine((T2 *) u, (P(T2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleOutLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF61) out, GF61 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleOutLine ((P(T2)) out, (T2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readCarryFusedLine(CP(GF61) in, GF61 *u, u32 line) { + readCarryFusedLine((CP(T2)) in, (T2 *) u, line); +} + +#endif diff --git a/src/cl/selftest.cl b/src/cl/selftest.cl index 7e42b217..3564f131 100644 --- a/src/cl/selftest.cl +++ b/src/cl/selftest.cl @@ -152,6 +152,7 @@ KERNEL(32) testTime(int what, global i64* io) { #endif } +#if FFT_FP64 KERNEL(256) testFFT3(global double2* io) { T2 u[4]; @@ -243,24 +244,12 @@ KERNEL(256) testFFT13(global double2* io) { } } -KERNEL(256) testTrig(global double2* out) { - for (i32 k = get_global_id(0); k < ND / 8; k += get_global_size(0)) { -#if 0 - double angle = M_PI / (ND / 2) * k; - out[k] = U2(cos(angle), -sin(angle)); -#else - out[k] = slowTrig_N(k, ND/8); -#endif - } -} - -KERNEL(256) testFFT(global double2* io) { -#define SIZE 16 - double2 u[SIZE]; +KERNEL(256) testFFT14(global double2* io) { + double2 u[14]; if (get_global_id(0) == 0) { - for (int i = 0; i < SIZE; ++i) { u[i] = io[i]; } - fft16(u); - for (int i = 0; i < SIZE; ++i) { io[i] = u[i]; } + for (int i = 0; i < 14; ++i) { u[i] = io[i]; } + fft14(u); + for (int i = 0; i < 14; ++i) { io[i] = u[i]; } } } @@ -273,11 +262,25 @@ KERNEL(256) testFFT15(global double2* io) { } } -KERNEL(256) testFFT14(global double2* io) { - double2 u[14]; +KERNEL(256) testFFT(global double2* io) { +#define SIZE 16 + double2 u[SIZE]; if (get_global_id(0) == 0) { - for (int i = 0; i < 14; ++i) { u[i] = io[i]; } - fft14(u); - for (int i = 0; i < 14; ++i) { io[i] = u[i]; } + for (int i = 0; i < SIZE; ++i) { u[i] = io[i]; } + fft16(u); + for (int i = 0; i < SIZE; ++i) { io[i] = u[i]; } } } + +KERNEL(256) testTrig(global double2* out) { + for (i32 k = get_global_id(0); k < ND / 8; k += get_global_size(0)) { +#if 0 + double angle = M_PI / (ND / 2) * k; + out[k] = U2(cos(angle), -sin(angle)); +#else + out[k] = slowTrig_N(k, ND/8); +#endif + } +} + +#endif diff --git a/src/cl/tailmul.cl b/src/cl/tailmul.cl index 8266f411..d5421b10 100644 --- a/src/cl/tailmul.cl +++ b/src/cl/tailmul.cl @@ -5,54 +5,45 @@ #include "trig.cl" #include "fftheight.cl" -// Why does this alternate implementation work? Let t' be the conjugate of t and note that t*t' = 1. -// Now consider these lines from the original implementation (comments appear alongside): -// b = mul_by_conjugate(b, t); -// X2(a, b); a + bt', a - bt' -// d = mul_by_conjugate(d, t); -// X2(c, d); c + dt', c - dt' -// a = mul(a, c); (a+bt')(c+dt') = ac + bct' + adt' + bdt'^2 -// b = mul(b, d); (a-bt')(c-dt') = ac - bct' - adt' + bdt'^2 -// X2(a, b); 2ac + 2bdt'^2, 2bct' + 2adt' -// b = mul(b, t); 2bc + 2ad - -void onePairMul(T2* pa, T2* pb, T2* pc, T2* pd, T2 conjugate_t_squared) { +#if FFT_FP64 + +// Handle the final multiplication step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. + +void OVERLOAD onePairMul(T2* pa, T2* pb, T2* pc, T2* pd, T2 t_squared) { T2 a = *pa, b = *pb, c = *pc, d = *pd; X2conjb(a, b); X2conjb(c, d); - T2 tmp = a; + *pa = cfma(a, c, cmul(cmul(b, d), -t_squared)); + *pb = cfma(b, c, cmul(a, d)); - a = cfma(a, c, cmul(cmul(b, d), conjugate_t_squared)); - b = cfma(b, c, cmul(tmp, d)); + X2_conjb(*pa, *pb); - X2conja(a, b); - - *pa = a; - *pb = b; + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); } -void pairMul(u32 N, T2 *u, T2 *v, T2 *p, T2 *q, T2 base_squared, bool special) { +void OVERLOAD pairMul(u32 N, T2 *u, T2 *v, T2 *p, T2 *q, T2 base_squared, bool special) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (special && i == 0 && me == 0) { - u[i] = conjugate(2 * foo2(u[i], p[i])); - v[i] = 4 * cmul(conjugate(v[i]), conjugate(q[i])); + u[i] = SWAP_XY(2 * foo2(u[i], p[i])); + v[i] = SWAP_XY(4 * cmul(v[i], q[i])); } else { - onePairMul(&u[i], &v[i], &p[i], &q[i], -base_squared); + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); } if (N == NH) { - onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], base_squared); + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], -base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], -new_base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); if (N == NH) { - onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], new_base_squared); + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], -new_base_squared); } } } @@ -100,12 +91,7 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { T2 trig = slowTrig_N(line1 + me * H, ND / NH); - if (line1) { - reverseLine(G_H, lds, v); - reverseLine(G_H, lds, q); - pairMul(NH, u, v, p, q, trig, false); - reverseLine(G_H, lds, v); - } else { + if (line1 == 0) { reverse(G_H, lds, u + NH/2, true); reverse(G_H, lds, p + NH/2, true); pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); @@ -116,6 +102,11 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { reverse(G_H, lds, q + NH/2, false); pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); } bar(); @@ -125,3 +116,379 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { writeTailFusedLine(v, out, memline2, me); writeTailFusedLine(u, out, memline1, me); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Handle the final multiplication step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. + +void OVERLOAD onePairMul(F2* pa, F2* pb, F2* pc, F2* pd, F2 t_squared) { + F2 a = *pa, b = *pb, c = *pc, d = *pd; + X2conjb(a, b); + X2conjb(c, d); + *pa = cfma(a, c, cmul(cmul(b, d), -t_squared)); + *pb = cfma(b, c, cmul(a, d)); + X2_conjb(*pa, *pb); + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairMul(u32 N, F2 *u, F2 *v, F2 *p, F2 *q, F2 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(2 * foo2(u[i], p[i])); + v[i] = SWAP_XY(4 * cmul(v[i], q[i])); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], -base_squared); + } + + F2 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], -new_base_squared); + } + } +} + +KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + CP(F2) aF2 = (CP(F2)) a; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH], v[NH]; + F2 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line1, me); + readTailFusedLine(inF2, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, aF2, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, aF2, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); +#else + readTailFusedLine(aF2, p, line1, me); + readTailFusedLine(aF2, q, line2, me); + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); + bar(); + fft_HEIGHT(lds, p, smallTrigF2); + bar(); + fft_HEIGHT(lds, q, smallTrigF2); +#endif + + F2 trig = slowTrig_N(line1 + me * H, ND / NH); + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + F2 trig2 = cmulFancy(trig, TAILT); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrigF2); + bar(); + fft_HEIGHT(lds, u, smallTrigF2); + writeTailFusedLine(v, outF2, memline2, me); + writeTailFusedLine(u, outF2, memline1, me); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD onePairMul(GF31* pa, GF31* pb, GF31* pc, GF31* pd, GF31 t_squared) { + GF31 a = *pa, b = *pb, c = *pc, d = *pd; + + X2conjb(a, b); + X2conjb(c, d); + + *pa = sub(cmul(a, c), cmul(cmul(b, d), t_squared)); + *pb = add(cmul(b, c), cmul(a, d)); + + X2_conjb(*pa, *pb); + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairMul(u32 N, GF31 *u, GF31 *v, GF31 *p, GF31 *q, GF31 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo2(u[i], p[i]))); + v[i] = SWAP_XY(shl(cmul(v[i], q[i]), 2)); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], neg(base_squared)); + } + + GF31 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], neg(new_base_squared)); + } + } +} + +KERNEL(G_H) tailMulGF31(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF31) a31 = (CP(GF31)) (a + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH], v[NH]; + GF31 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line1, me); + readTailFusedLine(in31, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, a31, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, a31, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); +#else + readTailFusedLine(a31, p, line1, me); + readTailFusedLine(a31, q, line2, me); + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, p, smallTrig31); + bar(); + fft_HEIGHT(lds, q, smallTrig31); +#endif + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS >= 1 + GF31 trig = smallTrig31[height_trigs + me]; // Trig values for line zero, should be cached +#if SINGLE_WIDE + GF31 mult = smallTrig31[height_trigs + G_H + line1]; +#else + GF31 mult = smallTrig31[height_trigs + G_H + line1 * 2]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*G_H + me]); +#else + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*2*G_H + me]); +#endif +#endif + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + GF31 trig2 = cmul(trig, TAILTGF31); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + writeTailFusedLine(v, out31, memline2, me); + writeTailFusedLine(u, out31, memline1, me); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD onePairMul(GF61* pa, GF61* pb, GF61* pc, GF61* pd, GF61 t_squared) { + GF61 a = *pa, b = *pb, c = *pc, d = *pd; + + X2conjb(a, b); + X2conjb(c, d); + GF61 e = subq(cmul(a, c), cmul(cmul(b, d), t_squared), 2); + GF61 f = addq(cmul(b, c), cmul(a, d)); + X2s_conjb(&e, &f, 4); + *pa = SWAP_XY(e), *pb = SWAP_XY(f); +} + +void OVERLOAD pairMul(u32 N, GF61 *u, GF61 *v, GF61 *p, GF61 *q, GF61 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo2(u[i], p[i]))); + v[i] = SWAP_XY(shl(cmul(v[i], q[i]), 2)); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], neg(base_squared)); + } + + GF61 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], neg(new_base_squared)); + } + } +} + +KERNEL(G_H) tailMulGF61(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + CP(GF61) a61 = (CP(GF61)) (a + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH], v[NH]; + GF61 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line1, me); + readTailFusedLine(in61, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, a61, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, a61, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); +#else + readTailFusedLine(a61, p, line1, me); + readTailFusedLine(a61, q, line2, me); + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, p, smallTrig61); + bar(); + fft_HEIGHT(lds, q, smallTrig61); +#endif + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS >= 1 + GF61 trig = smallTrig61[height_trigs + me]; // Trig values for line zero, should be cached +#if SINGLE_WIDE + GF61 mult = smallTrig61[height_trigs + G_H + line1]; +#else + GF61 mult = smallTrig61[height_trigs + G_H + line1 * 2]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*G_H + me]); +#else + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*2*G_H + me]); +#endif +#endif + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + GF61 trig2 = cmul(trig, TAILTGF61); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + writeTailFusedLine(v, out61, memline2, me); + writeTailFusedLine(u, out61, memline1, me); +} + +#endif diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 30105619..cb60abc5 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -23,52 +23,49 @@ #define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare #define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels -// Why does this alternate implementation work? Let t' be the conjugate of t and note that t*t' = 1. -// Now consider these lines from the original implementation (comments appear alongside): -// b = mul_by_conjugate(b, t); bt' -// X2(a, b); a + bt', a - bt' -// a = sq(a); a^2 + 2abt' + (bt')^2 -// b = sq(b); a^2 - 2abt' + (bt')^2 -// X2(a, b); 2a^2 + 2(bt')^2, 4abt' -// b = mul(b, t); 4ab - -void onePairSq(T2* pa, T2* pb, T2 conjugate_t_squared) { +#if FFT_FP64 + +// Handle the final squaring step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. +void OVERLOAD onePairSq(T2* pa, T2* pb, T2 t_squared) { T2 a = *pa; T2 b = *pb; // X2conjb(a, b); // *pb = mul2(cmul(a, b)); -// *pa = csqa(a, cmul(csq(b), conjugate_t_squared)); -// X2conja(*pa, *pb); +// *pa = csqa(a, cmul(csq(b), -t_squared)); +// X2_conjb(*pa, *pb); +// *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb) // Less readable version of the above that saves one complex add by using FMA instructions X2conjb(a, b); - T2 minusnewb = mulminus2(cmul(a, b)); // -newb = -2ab - *pb = csqa(a, cfma(csq(b), conjugate_t_squared, minusnewb)); // final b = newa - newb = a^2 + (bt')^2 - newb - (*pa).x = fma(-2.0, minusnewb.x, (*pb).x); // final a = newa + newb = finalb + 2 * newb - (*pa).y = fma(2.0, minusnewb.y, -(*pb).y); // conjugate(final a) + T2 twoab = mul2(cmul(a, b)); // 2ab + *pa = csqa(a, cfma(csq(b), -t_squared, twoab)); // final a = a^2 + 2ab - (bt)^2 + (*pb).x = fma(-2.0, twoab.x, (*pa).x); // final b = a^2 - 2ab - (bt)^2 + (*pb).y = fma(2.0, twoab.y, -(*pa).y); // conjugate(final b) + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); } -void pairSq(u32 N, T2 *u, T2 *v, T2 base_squared, bool special) { +void OVERLOAD pairSq(u32 N, T2 *u, T2 *v, T2 base_squared, bool special) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (special && i == 0 && me == 0) { - u[i] = 2 * foo(conjugate(u[i])); - v[i] = 4 * csq(conjugate(v[i])); + u[i] = SWAP_XY(2 * foo(u[i])); + v[i] = SWAP_XY(4 * csq(v[i])); } else { - onePairSq(&u[i], &v[i], -base_squared); + onePairSq(&u[i], &v[i], base_squared); } if (N == NH) { - onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared); + onePairSq(&u[i+NH/2], &v[i+NH/2], -base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &v[i+NH/4], -new_base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); if (N == NH) { - onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], new_base_squared); + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], -new_base_squared); } } } @@ -136,9 +133,10 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif #if ZEROHACK_H - fft_HEIGHT(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072), w); + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig + zerohack, w); bar(); - fft_HEIGHT(lds + (get_group_id(0) / 131072), v, smallTrig + (get_group_id(0) / 131072), w); + fft_HEIGHT(lds + zerohack, v, smallTrig + zerohack, w); #else fft_HEIGHT(lds, u, smallTrig, w); bar(); @@ -208,17 +206,17 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #else // Special pairSq for double-wide line 0 -void pairSq2_special(T2 *u, T2 base_squared) { +void OVERLOAD pairSq2_special(T2 *u, T2 base_squared) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (i == 0 && me == 0) { - u[0] = 2 * foo(conjugate(u[0])); - u[NH/2] = 4 * csq(conjugate(u[NH/2])); + u[0] = SWAP_XY(2 * foo(u[0])); + u[NH/2] = SWAP_XY(4 * csq(u[NH/2])); } else { - onePairSq(&u[i], &u[NH/2+i], -base_squared); + onePairSq(&u[i], &u[NH/2+i], base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], -new_base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); } } @@ -239,10 +237,10 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { u32 me = get_local_id(0); u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). - + // We're going to call the halves "first-half" and "second-half". bool isSecondHalf = me >= G_H; - + u32 line = !isSecondHalf ? line_u : line_v; // Read lines u and v @@ -255,7 +253,8 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif #if ZEROHACK_H - new_fft_HEIGHT2_1(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072), w); + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig + zerohack, w); #else new_fft_HEIGHT2_1(lds, u, smallTrig, w); #endif @@ -298,9 +297,7 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); pairSq(NH/2, u, u + NH/2, trig, false); - bar(G_H); - // We change the LDS halves we're using in order to enable half-bars revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); } @@ -313,3 +310,877 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { } #endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Handle the final squaring step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. +void OVERLOAD onePairSq(F2* pa, F2* pb, F2 t_squared) { + F2 a = *pa; + F2 b = *pb; + + X2conjb(a, b); + F2 twoab = mul2(cmul(a, b)); // 2ab + *pa = csqa(a, cfma(csq(b), -t_squared, twoab)); // final a = a^2 + 2ab - (bt)^2 + (*pb).x = fma(-2.0f, twoab.x, (*pa).x); // final b = a^2 - 2ab - (bt)^2 + (*pb).y = fma(2.0f, twoab.y, -(*pa).y); // conjugate(final b) + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairSq(u32 N, F2 *u, F2 *v, F2 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(2 * foo(u[i])); + v[i] = SWAP_XY(4 * csq(v[i])); + } else { + onePairSq(&u[i], &v[i], base_squared); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], -base_squared); + } + + F2 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], -new_base_squared); + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZero(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT / 2]; + F2 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? (H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line, me); + + F2 trig = slowTrig_N(line + me * H, ND / NH); + + fft_HEIGHT(lds, u, smallTrigF2); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + + bar(); + fft_HEIGHT(lds, u, smallTrigF2); + writeTailFusedLine(u, outF2, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line1, me); + readTailFusedLine(inF2, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrigF2 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrigF2 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); +#endif + + // Compute trig values from scratch. Good on GPUs with high DP throughput. +#if TAIL_TRIGS == 2 + F2 trig = slowTrig_N(line1 + me * H, ND / NH); + + // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. +#elif TAIL_TRIGS == 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached F2 per line + F2 trig = smallTrigF2[height_trigs + me]; // Trig values for line zero, should be cached + F2 mult = smallTrigF2[height_trigs + G_H + line1]; // Line multiplier + trig = cmulFancy(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + F2 trig = NTLOAD(smallTrigF2[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). + F2 trig2 = cmulFancy(trig, TAILT); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrigF2, w); + bar(); + fft_HEIGHT(lds, u, smallTrigF2, w); + + writeTailFusedLine(v, outF2, memline2, me); + writeTailFusedLine(u, outF2, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(F2 *u, F2 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(2 * foo(u[0])); + u[NH/2] = SWAP_XY(4 * csq(u[NH/2])); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared); + } + F2 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + } +} + +KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(inF2, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrigF2 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrigF2); +#endif + + // Compute trig values from scratch. Good on GPUs with high DP throughput. +#if TAIL_TRIGS == 2 + F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); + + // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. +#elif TAIL_TRIGS == 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached F2 per line + F2 trig = smallTrigF2[height_trigs + lowMe]; // Trig values for line zero, should be cached + F2 mult = smallTrigF2[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmulFancy(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + F2 trig = NTLOAD(smallTrigF2[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. + if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrigF2); + + // Write lines u and v + writeTailFusedLine(u, outF2, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD onePairSq(GF31* pa, GF31* pb, GF31 t_squared) { + GF31 a = *pa, b = *pb; + + X2conjb(a, b); + GF31 c = sub(csq(a), cmul(csq(b), t_squared)); + GF31 d = mul2(cmul(a, b)); + X2_conjb(c, d); + *pa = SWAP_XY(c), *pb = SWAP_XY(d); +} + +void OVERLOAD pairSq(u32 N, GF31 *u, GF31 *v, GF31 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo(u[i]))); + v[i] = SWAP_XY(shl(csq(v[i]), 2)); + } else { + onePairSq(&u[i], &v[i], base_squared); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], neg(base_squared)); //GWBUG -- can we write a special onepairsq that expects a base_squared that needs negation? + } + + GF31 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); //GWBUG -- or another special onePairSq that expects mul_t4'ed base_squared + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], neg(new_base_squared)); //GWBUG -- or another special onePairSq that expects mul_t4'ed and negated base_squared + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZeroGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? (H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line, me); + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS >= 1 + GF31 trig = smallTrig31[height_trigs + me]; +#if SINGLE_WIDE + GF31 mult = smallTrig31[height_trigs + G_H + line]; +#else + GF31 mult = smallTrig31[height_trigs + G_H + which]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF31 trig = NTLOAD(smallTrig31[height_trigs + line*G_H + me]); +#else + GF31 trig = NTLOAD(smallTrig31[height_trigs + which*G_H + me]); +#endif +#endif + + fft_HEIGHT(lds, u, smallTrig31); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + writeTailFusedLine(u, out31, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line1, me); + readTailFusedLine(in31, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig31 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrig31 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); +#endif + + // Compute trig values from scratch. Good on GPUs with relatively slow memory. +#if 0 && TAIL_TRIGS == 2 + GF31 trig = slowTrigGF31(line1 + me * H, ND / NH); + + // Do a little bit of memory access and a little bit of math. +#elif TAIL_TRIGS >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF31 per line + GF31 trig = smallTrig31[height_trigs + me]; // Trig values for line zero, should be cached + GF31 mult = smallTrig31[height_trigs + G_H + line1]; // Line multiplier + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). + GF31 trig2 = cmul(trig, TAILTGF31); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + + writeTailFusedLine(v, out31, memline2, me); + writeTailFusedLine(u, out31, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(GF31 *u, GF31 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(mul2(foo(u[0]))); + u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared); //GWBUG - why are we only using neg(base squareds) onePairSq could easily compensate for this + } + GF31 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + } +} + +KERNEL(G_H * 2) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(in31, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig31 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrig31); +#endif + + // Compute trig values from scratch. Good on GPUs with high MUL throughput?? +#if 0 && TAIL_TRIGS == 2 + GF31 trig = slowTrigGF31(line + H * lowMe, ND / NH * 2); + + // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. +#elif TAIL_TRIGS >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF31 per line + GF31 trig = smallTrig31[height_trigs + lowMe]; // Trig values for line zero, should be cached + GF31 mult = smallTrig31[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF31 trig = NTLOAD(smallTrig31[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. + if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrig31); + + // Write lines u and v + writeTailFusedLine(u, out31, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared) { + GF61 a = *pa, b = *pb; + + X2conjb(a, b); + GF61 c = subq(csq(a), cmul(csq(b), t_squared), 2); + GF61 d = 2 * cmul(a, b); + X2s_conjb(&c, &d, 4); + *pa = SWAP_XY(c), *pb = SWAP_XY(d); +} + +void OVERLOAD pairSq(u32 N, GF61 *u, GF61 *v, GF61 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo(u[i]))); + v[i] = SWAP_XY(shl(csq(v[i]), 2)); + } else { + onePairSq(&u[i], &v[i], base_squared); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], neg(base_squared)); //GWBUG -- can we write a special onepairsq that expects a base_squared that needs negation? + } + + GF61 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); //GWBUG -- or another special onePairSq that expects mul_t4'ed base_squared + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], neg(new_base_squared)); //GWBUG -- or another special onePairSq that expects mul_t4'ed and negated base_squared + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZeroGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? (H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line, me); + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS >= 1 + GF61 trig = smallTrig61[height_trigs + me]; +#if SINGLE_WIDE + GF61 mult = smallTrig61[height_trigs + G_H + line]; +#else + GF61 mult = smallTrig61[height_trigs + G_H + which]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF61 trig = NTLOAD(smallTrig61[height_trigs + line*G_H + me]); +#else + GF61 trig = NTLOAD(smallTrig61[height_trigs + which*G_H + me]); +#endif +#endif + + fft_HEIGHT(lds, u, smallTrig61); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + writeTailFusedLine(u, out61, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line1, me); + readTailFusedLine(in61, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig61 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrig61 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); +#endif + + // Compute trig values from scratch. Good on GPUs with relatively slow memory?? +#if 0 && TAIL_TRIGS == 2 + GF61 trig = slowTrigGF61(line1 + me * H, ND / NH); + + // Do a little bit of memory access and a little bit of math. +#elif TAIL_TRIGS >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF61 per line + GF61 trig = smallTrig61[height_trigs + me]; // Trig values for line zero, should be cached + GF61 mult = smallTrig61[height_trigs + G_H + line1]; // Line multiplier + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). + GF61 trig2 = cmul(trig, TAILTGF61); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + + writeTailFusedLine(v, out61, memline2, me); + writeTailFusedLine(u, out61, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(GF61 *u, GF61 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(mul2(foo(u[0]))); + u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared); //GWBUG - why are we only using neg(base squareds) onePairSq could easily compensate for this + } + GF61 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + } +} + +KERNEL(G_H * 2) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(in61, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig61 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrig61); +#endif + + // Compute trig values from scratch. Good on GPUs with high MUL throughput?? +#if 0 && TAIL_TRIGS == 2 + GF61 trig = slowTrigGF61(line + H * lowMe, ND / NH * 2); + + // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. +#elif TAIL_TRIGS >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF61 per line + GF61 trig = smallTrig61[height_trigs + lowMe]; // Trig values for line zero, should be cached + GF61 mult = smallTrig61[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF61 trig = NTLOAD(smallTrig61[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. + if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrig61); + + // Write lines u and v + writeTailFusedLine(u, out61, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif diff --git a/src/cl/tailutil.cl b/src/cl/tailutil.cl index a33f280a..8d6d6567 100644 --- a/src/cl/tailutil.cl +++ b/src/cl/tailutil.cl @@ -2,7 +2,9 @@ #include "math.cl" -void reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { +#if FFT_FP64 + +void OVERLOAD reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { u32 me = get_local_id(0); u32 revMe = WG - 1 - me + bump; @@ -24,7 +26,7 @@ void reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } } -void reverseLine(u32 WG, local T2 *lds2, T2 *u) { +void OVERLOAD reverseLine(u32 WG, local T2 *lds2, T2 *u) { u32 me = get_local_id(0); u32 revMe = WG - 1 - me; @@ -38,7 +40,7 @@ void reverseLine(u32 WG, local T2 *lds2, T2 *u) { } // This is used to reverse the second part of a line, and cross the reversed parts between the halves. -void revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { +void OVERLOAD revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { u32 me = get_local_id(0); u32 lowMe = me % WG; @@ -51,23 +53,11 @@ void revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } } -// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) -// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 -T2 foo2(T2 a, T2 b) { - a = addsub(a); - b = addsub(b); - return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); -} - -// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) -T2 foo(T2 a) { return foo2(a, a); } - - // // These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) // -void reverse2(local T2 *lds, T2 *u) { +void OVERLOAD reverse2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // For NH=8, u[0] to u[3] are left unchanged. Write to lds: @@ -98,7 +88,7 @@ void reverse2(local T2 *lds, T2 *u) { // u[2] u[3] // Returned in u[1] // v[3]rev v[2]rev // Returned in u[2] // v[1]rev v[0]rev // Returned in u[3] -void reverseLine2(local T2 *lds, T2 *u) { +void OVERLOAD reverseLine2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an @@ -134,7 +124,7 @@ void reverseLine2(local T2 *lds, T2 *u) { } // Undo a reverseLine2 -void unreverseLine2(local T2 *lds, T2 *u) { +void OVERLOAD unreverseLine2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially @@ -170,3 +160,503 @@ void unreverseLine2(local T2 *lds, T2 *u) { for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } #endif } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD reverse(u32 WG, local F2 *lds, F2 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local F2 *lds2, F2 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local F2 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local F2* lds2, F2 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? ((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT F2 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT F2 values). + + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local F2 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local F *ldsOut = (local F *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local F *ldsIn = (local F *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local F2 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local F *ldsOut = (local F *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local F *ldsIn = (local F *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD reverse(u32 WG, local GF31 *lds, GF31 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local GF31 *lds2, GF31 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local GF31 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local GF31* lds2, GF31 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? ((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT GF31 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT GF31 values). + + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local GF31 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local Z61 *ldsOut = (local Z61 *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local ZF61 *ldsIn = (local T *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local GF31 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local Z61 *ldsOut = (local T *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local Z61 *ldsIn = (local T *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD reverse(u32 WG, local GF61 *lds, GF61 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local GF61 *lds2, GF61 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local GF61 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local GF61* lds2, GF61 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? ((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT GF61 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT GF61 values). + + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local GF61 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local Z61 *ldsOut = (local Z61 *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local ZF61 *ldsIn = (local T *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local GF61 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local Z61 *ldsOut = (local T *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local Z61 *ldsIn = (local T *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif diff --git a/src/cl/transpose.cl b/src/cl/transpose.cl index be999393..8a1c9dae 100644 --- a/src/cl/transpose.cl +++ b/src/cl/transpose.cl @@ -2,8 +2,7 @@ #include "base.cl" -// Prototypes -void transposeWords(u32 W, u32 H, local Word2 *lds, global const Word2 *restrict in, global Word2 *restrict out); +#if WordSize <= 4 void transposeWords(u32 W, u32 H, local Word2 *lds, global const Word2 *restrict in, global Word2 *restrict out) { u32 GPW = W / 64, GPH = H / 64; @@ -38,3 +37,51 @@ KERNEL(64) transposeIn(P(Word2) out, CP(Word2) in) { local Word2 lds[4096]; transposeWords(BIG_HEIGHT, WIDTH, lds, in, out); } + +#else + +void transposeWords(u32 W, u32 H, local Word *lds, global const Word2 *restrict in, global Word2 *restrict out) { + u32 GPW = W / 64, GPH = H / 64; + + u32 g = get_group_id(0); + u32 gy = g % GPH; + u32 gx = g / GPH; + gx = (gy + gx) % GPW; + + in += 64 * W * gy + 64 * gx; + out += 64 * gy + 64 * H * gx; + u32 me = get_local_id(0); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + lds[i * 64 + me] = in[i * W + me].x; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + out[i * H + me].x = lds[me * 64 + i]; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + lds[i * 64 + me] = in[i * W + me].y; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + out[i * H + me].y = lds[me * 64 + i]; + } +} + +// from transposed to sequential. +KERNEL(64) transposeOut(P(Word2) out, CP(Word2) in) { + local Word lds[4096]; + transposeWords(WIDTH, BIG_HEIGHT, lds, in, out); +} + +// from sequential to transposed. +KERNEL(64) transposeIn(P(Word2) out, CP(Word2) in) { + local Word lds[4096]; + transposeWords(BIG_HEIGHT, WIDTH, lds, in, out); +} + +#endif diff --git a/src/cl/trig.cl b/src/cl/trig.cl index 69d32a2c..ebdd2af9 100644 --- a/src/cl/trig.cl +++ b/src/cl/trig.cl @@ -4,7 +4,9 @@ #include "math.cl" -double2 reducedCosSin(int k, double cosBase) { +#if FFT_FP64 + +T2 reducedCosSin(int k, double cosBase) { const double S[] = TRIG_SIN; const double C[] = TRIG_COS; @@ -36,12 +38,12 @@ double2 reducedCosSin(int k, double cosBase) { return U2(c, s); } -double2 fancyTrig_N(u32 k) { +T2 fancyTrig_N(u32 k) { return reducedCosSin(k, 0); } // Returns e^(i * tau * k / n), (tau == 2*pi represents a full circle). So k/n is the ratio of a full circle. -double2 slowTrig_N(u32 k, u32 kBound) { +T2 OVERLOAD slowTrig_N(u32 k, u32 kBound) { u32 n = ND; assert(n % 8 == 0); assert(k < kBound); // kBound actually bounds k @@ -61,11 +63,88 @@ double2 slowTrig_N(u32 k, u32 kBound) { assert(k <= n / 8); - double2 r = reducedCosSin(k, 1); + T2 r = reducedCosSin(k, 1); - if (flip) { r = swap(r); } + if (flip) { r = SWAP_XY(r); } if (negateCos) { r.x = -r.x; } if (negate) { r = -r; } return r; } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 reducedCosSin(int k, double cosBase) { + const float S[] = TRIG_SIN; + const float C[] = TRIG_COS; + + float x = k * TRIG_SCALE; + float z = x * x; + + float r1 = fma(S[7], z, S[6]); + float r2 = fma(C[7], z, C[6]); + + r1 = fma(r1, z, S[5]); + r2 = fma(r2, z, C[5]); + + r1 = fma(r1, z, S[4]); + r2 = fma(r2, z, C[4]); + + r1 = fma(r1, z, S[3]); + r2 = fma(r2, z, C[3]); + + r1 = fma(r1, z, S[2]); + r2 = fma(r2, z, C[2]); + + r1 = fma(r1, z, S[1]); + r2 = fma(r2, z, C[1]); + + r1 = r1 * x; + float c = fma(r2, z, (float) cosBase); + float s = fma(x, S[0], r1); + + return U2(c, s); +} + +F2 fancyTrig_N(u32 k) { + return reducedCosSin(k, 0); +} + +// Returns e^(i * tau * k / n), (tau == 2*pi represents a full circle). So k/n is the ratio of a full circle. +F2 OVERLOAD slowTrig_N(u32 k, u32 kBound) { + u32 n = ND; + assert(n % 8 == 0); + assert(k < kBound); // kBound actually bounds k + assert(kBound <= 2 * n); // angle <= 2 tau + + if (kBound > n && k >= n) { k -= n; } + assert(k < n); + + bool negate = kBound > n/2 && k >= n/2; + if (negate) { k -= n/2; } + + bool negateCos = kBound > n / 4 && k >= n / 4; + if (negateCos) { k = n/2 - k; } + + bool flip = kBound > n / 8 + 1 && k > n / 8; + if (flip) { k = n / 4 - k; } + + assert(k <= n / 8); + + F2 r = reducedCosSin(k, 1); + + if (flip) { r = SWAP_XY(r); } + if (negateCos) { r.x = -r.x; } + if (negate) { r = -r; } + + return r; +} + +#endif diff --git a/src/cl/weight.cl b/src/cl/weight.cl index 9967c936..62e65d96 100644 --- a/src/cl/weight.cl +++ b/src/cl/weight.cl @@ -3,6 +3,8 @@ #define STEP (NWORDS - (EXP % NWORDS)) // bool isBigWord(u32 extra) { return extra < NWORDS - STEP; } +#if FFT_FP64 + T fweightStep(u32 i) { const T TWO_TO_NTH[8] = { // 2^(k/8) -1 for k in [0..8) @@ -69,3 +71,65 @@ T optionalHalve(T w) { // return w >= 4 ? w / 2 : w; //u.y = bfi(u.y, 0xffefffff, 0); return as_double(u); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F fweightStep(u32 i) { + const F TWO_TO_NTH[8] = { + // 2^(k/8) -1 for k in [0..8) + 0, + 0.090507732665257662, + 0.18920711500272105, + 0.29683955465100964, + 0.41421356237309503, + 0.54221082540794086, + 0.68179283050742912, + 0.83400808640934243, + }; + return TWO_TO_NTH[i * STEP % NW * (8 / NW)]; +} + +F iweightStep(u32 i) { + const F TWO_TO_MINUS_NTH[8] = { + // 2^-(k/8) - 1 for k in [0..8) + 0, + -0.082995956795328771, + -0.15910358474628547, + -0.2288945872960296, + -0.29289321881345248, + -0.35158022267449518, + -0.40539644249863949, + -0.45474613366737116, + }; + return TWO_TO_MINUS_NTH[i * STEP % NW * (8 / NW)]; +} + +F optionalDouble(F iw) { + // In a straightforward implementation, inverse weights are between 0.5 and 1.0. We use inverse weights between 1.0 and 2.0 + // because it allows us to implement this routine with a single OR instruction on the exponent. The original implementation + // where this routine took as input values from 0.25 to 1.0 required both an AND and an OR instruction on the exponent. + // return iw <= 1.0 ? iw * 2 : iw; + assert(iw > 0.5 && iw < 2); + uint u = as_uint(iw); + u |= 0x00800000; + return as_float(u); +} + +F optionalHalve(F w) { // return w >= 4 ? w / 2 : w; + // In a straightforward implementation, weights are between 1.0 and 2.0. We use weights between 2.0 and 4.0 because + // it allows us to implement this routine with a single AND instruction on the exponent. The original implementation + // where this routine took as input values from 1.0 to 4.0 required both an AND and an OR instruction on the exponent. + assert(w >= 2 && w < 8); + uint u = as_uint(w); + u &= 0xFF7FFFFF; + return as_float(u); +} + +#endif diff --git a/src/common.h b/src/common.h index fbf5deda..6daab79a 100644 --- a/src/common.h +++ b/src/common.h @@ -11,6 +11,8 @@ using i32 = int32_t; using u32 = uint32_t; using i64 = int64_t; using u64 = uint64_t; +using i128 = __int128; +using u128 = unsigned __int128; using f128 = __float128; static_assert(sizeof(u8) == 1, "size u8"); @@ -21,6 +23,30 @@ using namespace std; namespace std::filesystem{}; namespace fs = std::filesystem; +#define FFT_FP64 1 +#define FFT_FP32 0 +#define NTT_GF31 0 +#define NTT_GF61 0 +#define NTT_NCW 0 + +// When using multiple primes in an NTT the size of an integer FFT "word" grows such that we need to support words larger than 32-bits +#if (FFT_FP64 && NTT_GF31) | (FFT_FP32 && NTT_GF61) | (NTT_GF31 && NTT_GF61) +typedef i64 Word; +typedef u64 uWord; // Used by unbalance +#else +typedef i32 Word; +typedef u32 uWord; // Used by unbalance +#endif + +using double2 = pair; +using float2 = pair; +using int2 = pair; +using long2 = pair; +using uint = unsigned int; +using uint2 = pair; +using ulong = unsigned long; +using ulong2 = pair; + std::vector split(const string& s, char delim); string hex(u64 x); diff --git a/src/main.cpp b/src/main.cpp index e1b00588..3ded0a9c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -20,8 +20,6 @@ #include // #include from GCC-13 onwards -namespace fs = std::filesystem; - void gpuWorker(GpuCommon shared, Queue *q, i32 instance) { // LogContext context{(instance ? shared.args->tailDir() : ""s) + to_string(instance) + ' '}; // log("Starting worker %d\n", instance); @@ -42,13 +40,13 @@ void gpuWorker(GpuCommon shared, Queue *q, i32 instance) { } -#ifdef __MINGW32__ // for Windows +#if defined(__MINGW32__) || defined(__MINGW64__) // for Windows extern int putenv(const char *); #endif int main(int argc, char **argv) { -#ifdef __MINGW32__ +#if defined(__MINGW32__) || defined(__MINGW64__) putenv("ROC_SIGNAL_POOL_SIZE=32"); #else // Required to work around a ROCm bug when using multiple queues @@ -66,7 +64,7 @@ int main(int argc, char **argv) { fs::current_path(args.dir); } } - + fs::path poolDir; { Args args{true}; @@ -74,24 +72,27 @@ int main(int argc, char **argv) { args.parse(mainLine); poolDir = args.masterDir; } - + initLog("gpuowl-0.log"); log("PRPLL %s starting\n", VERSION); - + Args args; if (!poolDir.empty()) { args.readConfig(poolDir / "config.txt"); } args.readConfig("config.txt"); args.parse(mainLine); args.setDefaults(); - + if (args.maxAlloc) { AllocTrac::setMaxAlloc(args.maxAlloc); } Context context(getDevice(args.device)); - TrigBufCache bufCache{&context}; Signal signal; Background background; - GpuCommon shared{&args, &bufCache, &background}; + GpuCommon shared; + shared.args = &args; + TrigBufCache bufCache{&context}; + shared.bufCache = &bufCache; + shared.background = &background; if (args.doCtune || args.doTune || args.doZtune || args.carryTune) { Queue q(context, args.profile); diff --git a/src/state.cpp b/src/state.cpp index 617a2e20..b860aceb 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -8,29 +8,29 @@ #include #include -static int lowBits(int u, int bits) { return (u << (32 - bits)) >> (32 - bits); } +static i64 lowBits(i64 u, int bits) { return (u << (64 - bits)) >> (64 - bits); } -static u32 unbalance(int w, int nBits, int *carry) { +static uWord unbalance(Word w, int nBits, int *carry) { assert(*carry == 0 || *carry == -1); w += *carry; *carry = 0; if (w < 0) { - w += (1 << nBits); + w += ((Word) 1 << nBits); *carry = -1; } - if (!(0 <= w && w < (1 << nBits))) { log("w=%d, nBits=%d\n", w, nBits); } - assert(0 <= w && w < (1 << nBits)); + if (!(0 <= w && w < ((Word) 1 << nBits))) { log("w=%lX, nBits=%d\n", (long) w, nBits); } + assert(0 <= w && w < ((Word) 1 << nBits)); return w; } -std::vector compactBits(const vector &dataVect, u32 E) { +std::vector compactBits(const vector &dataVect, u32 E) { if (dataVect.empty()) { return {}; } // Indicating all zero std::vector out; out.reserve((E - 1) / 32 + 1); u32 N = dataVect.size(); - const int *data = dataVect.data(); + const Word *data = dataVect.data(); int carry = 0; u32 outWord = 0; @@ -38,20 +38,26 @@ std::vector compactBits(const vector &dataVect, u32 E) { for (u32 p = 0; p < N; ++p) { int nBits = bitlen(N, E, p); - u32 w = unbalance(data[p], nBits, &carry); + uWord w = unbalance(data[p], nBits, &carry); assert(nBits > 0); - assert(w < (1u << nBits)); - + assert(w < ((uWord) 1 << nBits)); assert(haveBits < 32); - int topBits = 32 - haveBits; - outWord |= w << haveBits; - if (nBits >= topBits) { - out.push_back(outWord); - outWord = w >> topBits; - haveBits = nBits - topBits; - } else { - haveBits += nBits; + + while (nBits) { + int topBits = 32 - haveBits; + outWord |= w << haveBits; + if (nBits >= topBits) { + w >>= topBits; + nBits -= topBits; + out.push_back(outWord); + outWord = 0; + haveBits = 0; + } else { + haveBits += nBits; + w >>= nBits; + break; + } } } @@ -69,20 +75,20 @@ std::vector compactBits(const vector &dataVect, u32 E) { } struct BitBucket { - u64 bits; + u128 bits; u32 size; BitBucket() : bits(0), size(0) {} void put32(u32 b) { - assert(size <= 32); - bits += (u64(b) << size); + assert(size <= 96); + bits += (u128(b) << size); size += 32; } - int popSigned(u32 n) { + i64 popSigned(u32 n) { assert(size >= n); - int b = lowBits(bits, n); + i64 b = lowBits((i64) bits, n); size -= n; bits >>= n; bits += (b < 0); // carry fixup. @@ -90,22 +96,22 @@ struct BitBucket { } }; -vector expandBits(const vector &compactBits, u32 N, u32 E) { +vector expandBits(const vector &compactBits, u32 N, u32 E) { assert(E % 32 != 0); - std::vector out(N); - int *data = out.data(); + std::vector out(N); + Word *data = out.data(); BitBucket bucket; auto it = compactBits.cbegin(); [[maybe_unused]] auto itEnd = compactBits.cend(); for (u32 p = 0; p < N; ++p) { - u32 len = bitlen(N, E, p); - if (bucket.size < len) { + u32 len = bitlen(N, E, p); + while (bucket.size < len) { assert(it != itEnd); bucket.put32(*it++); } - data[p] = bucket.popSigned(len); + data[p] = (Word) bucket.popSigned(len); } assert(it == itEnd); assert(bucket.size == 32 - E % 32); diff --git a/src/state.h b/src/state.h index 4b32be04..9b37c0fd 100644 --- a/src/state.h +++ b/src/state.h @@ -8,8 +8,8 @@ #include #include -vector compactBits(const vector &dataVect, u32 E); -vector expandBits(const vector &compactBits, u32 N, u32 E); +vector compactBits(const vector &dataVect, u32 E); +vector expandBits(const vector &compactBits, u32 N, u32 E); constexpr u32 step(u32 N, u32 E) { return N - (E % N); } constexpr u32 extra(u32 N, u32 E, u32 k) { return u64(step(N, E)) * k % N; } From 1e633900847f9151fb313c7a00624a32e002a035 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 14 Sep 2025 01:52:32 +0000 Subject: [PATCH 018/115] Fixed typos in GF31 + GF61 weights shift calculations --- src/cl/carry.cl | 4 ++-- src/cl/carryfused.cl | 4 ++-- src/cl/fftp.cl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 35c34ee0..dfb32fb2 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -501,10 +501,10 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; - const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; - const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; // Derive the big vs. little flags from the fractional number of bits in each word. // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index a571aef7..9fcb6da6 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -1668,10 +1668,10 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; - const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; - const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; // Derive the big vs. little flags from the fractional number of bits in each word. // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index 9e59db30..eafebfe8 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -414,10 +414,10 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; - const u32 m31_bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; - const u32 m61_bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; // Derive the big vs. little flags from the fractional number of bits in each word. // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). From 38d3ae1f47c3f42a806ad8485b795b459195a1ba Mon Sep 17 00:00:00 2001 From: george Date: Wed, 17 Sep 2025 02:08:33 +0000 Subject: [PATCH 019/115] Improved GF31 and GF61 math --- src/cl/fft16.cl | 3 +- src/cl/fft4.cl | 5 +- src/cl/fft8.cl | 51 ++++---------- src/cl/math.cl | 160 ++++++++++++++++++++++++++----------------- src/cl/tailsquare.cl | 2 +- 5 files changed, 118 insertions(+), 103 deletions(-) diff --git a/src/cl/fft16.cl b/src/cl/fft16.cl index 5603a25c..7cbbb24b 100644 --- a/src/cl/fft16.cl +++ b/src/cl/fft16.cl @@ -177,9 +177,8 @@ void OVERLOAD fft16(F2 *u) { u[15] = cmul(u[15], U2(-C1, S1)); // 7t16 u[10] = mul_t8(u[10]); - u[14] = mul_3t8(u[14]); - u[12] = mul_t4(u[12]); + u[14] = mul_3t8(u[14]); fft8Core(u); fft8Core(u + 8); diff --git a/src/cl/fft4.cl b/src/cl/fft4.cl index 388d4aef..29eaa786 100644 --- a/src/cl/fft4.cl +++ b/src/cl/fft4.cl @@ -142,8 +142,7 @@ void OVERLOAD fft4(F2 *u) { fft4by(u, 0, 1, 4); } void OVERLOAD fft4Core(GF31 *u) { X2(u[0], u[2]); - X2(u[1], u[3]); u[3] = mul_t4(u[3]); - + X2_mul_t4(u[1], u[3]); X2(u[0], u[1]); X2(u[2], u[3]); } @@ -153,7 +152,7 @@ void OVERLOAD fft4by(GF31 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + step * k) % M] - Z31 x0 = add(A(0).x, A(2).x); //GWBUG: Delay some of the mods (we have three spare bits) + Z31 x0 = add(A(0).x, A(2).x); //GWBUG: Delay some of the mods using 64 bit temps? Z31 x2 = sub(A(0).x, A(2).x); Z31 y0 = add(A(0).y, A(2).y); Z31 y2 = sub(A(0).y, A(2).y); diff --git a/src/cl/fft8.cl b/src/cl/fft8.cl index bf4332d7..27e7f816 100644 --- a/src/cl/fft8.cl +++ b/src/cl/fft8.cl @@ -83,10 +83,10 @@ void OVERLOAD fft8(F2 *u) { #if NTT_GF31 void OVERLOAD fft8Core(GF31 *u) { - X2(u[0], u[4]); //GWBUG: Delay some mods using extra 3 bits of Z61 - X2(u[1], u[5]); u[5] = mul_t8(u[5]); - X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); - X2(u[3], u[7]); u[7] = mul_3t8(u[7]); + X2(u[0], u[4]); + X2_mul_t8(u[1], u[5]); + X2_mul_t4(u[2], u[6]); + X2_mul_3t8(u[3], u[7]); fft4Core(u); fft4Core(u + 4); } @@ -108,27 +108,15 @@ void OVERLOAD fft8(GF31 *u) { #if NTT_GF61 -#if 0 // Working code. Fairly readable. - -// Same as mul_t8, but negation of a.y is delayed -GF61 OVERLOAD mul_t8_special(GF61 a) { return U2(shl(a.y + neg(a.x, 2), 30), shl(a.x + a.y, 30)); } -// Same as neg(a.y), X2_mul_t4(a, b) -void OVERLOAD X2_mul_t4_special(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(b->y, a->y); t.x = sub(t.x, b->x); b->x = add(b->y, t.y); b->y = t.x; } - -void OVERLOAD fft4CoreSpecialU1(GF61 *u) { // u[1].y needs negation - X2(u[0], u[2]); - X2_mul_t4_special(&u[1], &u[3]); // u[1].y = -u[1].y; X2(u[1], u[3]); u[3] = mul_t4(u[3]); - X2(u[0], u[1]); - X2(u[2], u[3]); -} +#if 0 // Working code. void OVERLOAD fft8Core(GF61 *u) { X2(u[0], u[4]); //GWBUG: Delay some mods using extra 3 bits of Z61 - X2(u[1], u[5]); u[5] = mul_t8_special(u[5]); // u[5] = mul_t8(u[5]); But u[5].y needs negation + X2_mul_t8(u[1], u[5]); // X2(u[1], u[5]); u[5] = mul_t8(u[5]); X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); - X2(u[3], u[7]); u[7] = mul_3t8(u[7]); + X2_mul_3t8(u[3], u[7]); // X2(u[3], u[7]); u[7] = mul_3t8(u[7]); fft4Core(u); - fft4CoreSpecialU1(u + 4); + fft4Core(u + 4); } // 4 MUL + 52 ADD @@ -139,38 +127,29 @@ void OVERLOAD fft8(GF61 *u) { SWAP(u[3], u[6]); } -#else // Carefully track the size of numbers to reduce the numberof mod M61 reductions - -// Same as mul_t8, but negation of a.y is delayed and a custom m61_count -GF61 OVERLOAD mul_t8_special(GF61 a, u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), a.x + a.y), 30); } -// Same as mul_3t8, but with a custom m61_count -GF61 OVERLOAD mul_3t8_special(GF61 a, u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } -// Same as neg(a.y), X2q_mul_t4(a, b, m61_count) -void OVERLOAD X2q_mul_t4_special(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; a->x = a->x + b->x; a->y = b->y + neg(a->y, m61_count); t.x = t.x + neg(b->x, m61_count); b->x = b->y + t.y; b->y = t.x; } +#else // Carefully track the size of numbers to reduce the number of mod M61 reductions void OVERLOAD fft4CoreSpecial1(GF61 *u) { // Starts with u[0,1,2,3] having maximum values of (2,2,3,2)*M61+epsilon. X2q(&u[0], &u[2], 4); // X2(u[0], u[2]); No reductions mod M61. u[0,2] max value is 5,6*M61+epsilon. X2q_mul_t4(&u[1], &u[3], 3); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 5,4*M61+epsilon. - u[1] = mod(u[1]); u[2] = mod(u[2]); // Reduce the worst offenders. u[0,1,2,3] have maximum values of (5,1,1,4)*M61+epsilon. + u[1] = modM61(u[1]); u[2] = modM61(u[2]); // Reduce the worst offenders. u[0,1,2,3] have maximum values of (5,1,1,4)*M61+epsilon. X2s(&u[0], &u[1], 2); // u[0,1] max value before reduction is 6,7*M61+epsilon X2s(&u[2], &u[3], 5); // u[2,3] max value before reduction is 5,6*M61+epsilon } void OVERLOAD fft4CoreSpecial2(GF61 *u) { // Like above, u[1].y needs negation. Starts with u[0,1,2,3] having maximum values of (3,1,2,1)*M61+epsilon. X2q(&u[0], &u[2], 3); // u[0,2] max value is 5,6*M61+epsilon. - X2q_mul_t4_special(&u[1], &u[3], 2); // u[1].y = -u[1].y; X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 3,2*M61+epsilon. - u[0] = mod(u[0]); u[2] = mod(u[2]); // Reduce the worst offenders u[0,1,2,3] have maximum values of (1,3,1,2)*M61+epsilon. + X2q_mul_t4(&u[1], &u[3], 2); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 3,2*M61+epsilon. + u[0] = modM61(u[0]); u[2] = modM61(u[2]); // Reduce the worst offenders u[0,1,2,3] have maximum values of (1,3,1,2)*M61+epsilon. X2s(&u[0], &u[1], 4); // u[0,1] max value before reduction is 4,5*M61+epsilon X2s(&u[2], &u[3], 3); // u[2,3] max value before reduction is 3,4*M61+epsilon } void OVERLOAD fft8Core(GF61 *u) { // Starts with all u[i] having maximum values of M61+epsilon. X2q(&u[0], &u[4], 2); // X2(u[0], u[4]); No reductions mod M61. u[0,4] max value is 2,3*M61+epsilon. - X2q(&u[1], &u[5], 2); // X2(u[1], u[5]); u[1,5] max value is 2,3*M61+epsilon. - u[5] = mul_t8_special(u[5], 4); // u[5] = mul_t8(u[5]); u[5].y needs neg. u[5] max value is 1*M61+epsilon. - X2q_mul_t4(&u[2], &u[6], 2); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); u[2,6] max value is 3,2*M61+epsilon. - X2q(&u[3], &u[7], 2); // X2(u[3], u[7]); u[3,7] max value is 2,3*M61+epsilon. - u[7] = mul_3t8_special(u[7], 4); // u[7] = mul_3t8(u[7]); u[7] max value is 1*M61+epsilon. + X2q_mul_t8(&u[1], &u[5], 2); // X2(u[1], u[5]); u[5] = mul_t8(u[5]); u[1,5] max value is 2,1*M61+epsilon. + X2q_mul_t4(&u[2], &u[6], 2); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); u[2,6] max value is 3,2*M61+epsilon. + X2q_mul_3t8(&u[3], &u[7], 2); // X2(u[3], u[7]); u[7] = mul_3t8(u[7]); u[3,7] max value is 2,1*M61+epsilon. fft4CoreSpecial1(u); fft4CoreSpecial2(u + 4); } diff --git a/src/cl/math.cl b/src/cl/math.cl index d112a180..203063d9 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -33,7 +33,6 @@ u32 i96_lo32(i96 val) { return val.a.lo32; } #define X2_mul_t4(a, b) X2_mul_t4_internal(&(a), &(b)) // X2(a, b), b = mul_t4(b) #define X2_mul_t8(a, b) X2_mul_t8_internal(&(a), &(b)) // X2(a, b), b = mul_t8(b) #define X2_mul_3t8(a, b) X2_mul_3t8_internal(&(a), &(b)) // X2(a, b), b = mul_3t8(b) -#define X2_conja(a, b) X2_conja_internal(&(a), &(b)) // X2(a, b), a = conjugate(a) // NOT USED #define X2_conjb(a, b) X2_conjb_internal(&(a), &(b)) // X2(a, b), b = conjugate(b) #define SWAP(a, b) SWAP_internal(&(a), &(b)) // a = b, b = a #define SWAP_XY(a) U2((a).y, (a).x) // Swap real and imaginary components of a @@ -126,9 +125,6 @@ void OVERLOAD X2_mul_t4_internal(T2 *a, T2 *b) { T2 t = *a; *a = *a + *b; t.x = // Same as X2(a, conjugate(b)) void OVERLOAD X2conjb_internal(T2 *a, T2 *b) { T2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(T2 *a, T2 *b) { T2 t = *a; a->x = a->x + b->x; a->y = - (a->y + b->y); *b = t - *b; } - // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(T2 *a, T2 *b) { T2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } @@ -160,7 +156,7 @@ T2 foo(T2 a) { return foo2(a, a); } F2 OVERLOAD conjugate(F2 a) { return U2(a.x, -a.y); } // Multiply by 2 without using floating point instructions. This is a little sloppy as an input of zero returns 2^-126. -F OVERLOAD mul2(F a) { return a + a; } //{ int tmp = as_int(a); tmp += 0x00800000; /* Bump exponent by 1 */ return (as_float(tmp)); } +F OVERLOAD mul2(F a) { return a + a; } //{ int tmp = as_int(a); tmp += 0x00800000; /* Bump exponent by 1 */ return (as_float(tmp)); } F2 OVERLOAD mul2(F2 a) { return U2(mul2(a.x), mul2(a.y)); } // Multiply by -2 without using floating point instructions. This is a little sloppy as an input of zero returns -2^-126. @@ -243,9 +239,6 @@ void OVERLOAD X2_mul_t4_internal(F2 *a, F2 *b) { F2 t = *a; *a = *a + *b; t.x = // Same as X2(a, conjugate(b)) void OVERLOAD X2conjb_internal(F2 *a, F2 *b) { F2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(F2 *a, F2 *b) { F2 t = *a; a->x = a->x + b->x; a->y = - (a->y + b->y); *b = t - *b; } - // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(F2 *a, F2 *b) { F2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } @@ -363,10 +356,7 @@ void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, * void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } // Same as X2(a, b), b = mul_3t8(b) -void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? - -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } @@ -381,35 +371,63 @@ GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } -#elif 1 // This version is a little sloppy. Returns values in 0..M31 range //GWBUG (could this handle M31+1 too> neg() is hard. If so made_Z31(i64) is faster +#elif 1 // This version is a little sloppy. Returns values in 0..M31 range //GWBUG (could this handle M31+1 too> neg() is hard. If so made_Z31(i64) is faster -// Internal routine to return value in 0..M31 range -Z31 OVERLOAD mod(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return M31+1 +// Internal routines to return value in 0..M31 range +Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF00000000 + u32 alo = a & M31; + u32 amid = (a >> 31) & M31; + u32 ahi = a >> 62; + return modM31(ahi + amid + alo); // 32-bit overflow does not occur due to restrictions on input +} +Z31 OVERLOAD modM31(i64 a) { // abs(a) must be less than 0x7FFFFFFF80000000 + u32 alo = a & M31; + u32 amid = ((u64) a >> 31) & M31; // Unsigned shift might be faster than signed shift + u32 ahi = a >> 62; // Sign extend the top bits + return modM31(ahi + amid + alo); // This is where caller must assure a 32-bit overflow does not occur +} +Z31 OVERLOAD modM31q(u64 a) { // Quick version, a < 2^62 + u32 alo = a & M31; + u32 ahi = a >> 31; + return modM31(ahi + alo); +} +#if 0 // GWBUG - which is faster? +Z31 OVERLOAD modM31q(i64 a) { // Quick version, abs(a) must be 61 bits + u32 alo = a & M31; + i32 ahi = a >> 31; // Sign extend the top bits + if (ahi < 0) ahi = ahi + M31; + return modM31((u32) ahi + alo); +} +#else +Z31 OVERLOAD modM31q(i64 a) { return modM31(a); } // Quick version, abs(a) must be 61 bits +#endif -Z31 OVERLOAD neg(Z31 a) { return M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +Z31 OVERLOAD neg(Z31 a) { return M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided GF31 OVERLOAD neg(GF31 a) { return U2(neg(a.x), neg(a.y)); } -Z31 OVERLOAD add(Z31 a, Z31 b) { return mod(a + b); } +Z31 OVERLOAD add(Z31 a, Z31 b) { return modM31(a + b); } GF31 OVERLOAD add(GF31 a, GF31 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } -Z31 OVERLOAD sub(Z31 a, Z31 b) { return mod(a + neg(b)); } +Z31 OVERLOAD sub(Z31 a, Z31 b) { i32 t = a - b; return (t & M31) + (t >> 31); } GF31 OVERLOAD sub(GF31 a, GF31 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } -Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a -Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 -Z31 OVERLOAD make_Z31(i64 a) { if (a < 0) a += (((i64) M31 << 31) + M31); return add((Z31) (a & M31), (Z31) (a >> 31)); } // Handles 62-bit a values +Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a +Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 +Z31 OVERLOAD make_Z31(i64 a) { return modM31q(a); } // Handles range -2^61..2^61 u32 get_Z31(Z31 a) { return a == M31 ? 0 : a; } // Get value in range 0 to M31-1 i32 get_balanced_Z31(Z31 a) { return (a & 0xC0000000) ? (i32) a - M31 : (i32) a; } // Get balanced value in range -M31/2 to M31/2 // Assumes k reduced mod 31. -Z31 OVERLOAD shl(Z31 a, u32 k) { return ((a << k) + (a >> (31 - k))) & M31; } -GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } -Z31 OVERLOAD shr(Z31 a, u32 k) { return ((a >> k) + (a << (31 - k))) & M31; } +Z31 OVERLOAD shr(Z31 a, u32 k) { return (a >> k) + ((a << (31 - k)) & M31); } GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } +Z31 OVERLOAD shl(Z31 a, u32 k) { return shr(a, 31 - k); } +GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } -//Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } //GWBUG. is M31 * M31 a problem???? I think so! needs double mod -Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return mod(add((Z31) (t & M31), (Z31) (t >> 31))); } //Fixes the M31 * M31 problem +//Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } //GWBUG. is M31 * M31 a problem???? I think so! needs double mod +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return modM31(add((Z31) (t & M31), (Z31) (t >> 31))); } //Fixes the M31 * M31 problem Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? @@ -421,19 +439,46 @@ GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } // Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). -GF31 OVERLOAD csq(GF31 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs +GF31 OVERLOAD csq(GF31 a) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} // a^2 + c -GF31 OVERLOAD csqa(GF31 a, GF31 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? Find a way to use fma instructions +GF31 OVERLOAD csqa(GF31 a, GF31 c) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r + c.x), modM31(i + c.y)); +} + +// a^2 - c +GF31 OVERLOAD csq_sub(GF31 a, GF31 c) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r + neg(c.x)), modM31((i64) i - c.y)); // GWBUG - check that the compiler generates MAD instructions +} // Complex mul -//GF31 OVERLOAD cmul(GF31 a, GF31 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +#if 1 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. GF31 OVERLOAD cmul(GF31 a, GF31 b) { - Z31 k1 = mul(b.x, add(a.x, a.y)); - Z31 k2 = mul(a.x, sub(b.y, b.x)); - Z31 k3 = mul(a.y, add(b.y, b.x)); - return U2(sub(k1, k3), add(k1, k2)); + u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 k2 = a.x * (u64) (b.y + neg(b.x)); + u64 k3 = a.y * (u64) (b.y + b.x); + i64 k1k3 = k1 - k3; // signed 63-bit value, absolute value <= 7FFF FFFE 0000 0002 + u64 k1k2 = k1 + k2; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + return U2(modM31(k1k3), modM31(k1k2)); +} +#else +GF31 OVERLOAD cmul(GF31 a, GF31 b) { + u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 k2 = a.x * (u64) (b.y + neg(b.x)); + u64 k3 = neg(a.y) * (u64) (b.y + b.x); + u64 k1k3 = k1 + k3; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u64 k1k2 = k1 + k2; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + return U2(modM31(k1k3), modM31(k1k2)); } +#endif GF31 OVERLOAD cfma(GF31 a, GF31 b, GF31 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? @@ -469,10 +514,7 @@ void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, * void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } // Same as X2(a, b), b = mul_3t8(b) -void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? - -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); t = sub(*b, t); *b = shl(U2(add(t.x, t.y), sub(t.y, t.x)), 15); } // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } @@ -599,9 +641,6 @@ void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_t8(*b) // Same as X2(a, b), b = mul_3t8(b) void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_3t8(*b); } -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } - // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } @@ -616,8 +655,8 @@ GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } // In function names, "q" stands for quick, "s" stands for slow (i.e. does mod). // These functions are untested with this strict Z61 implementation. Callers need to eliminate all uses of + or - operators. -Z61 OVERLOAD mod(Z61 a) { return a; } -GF61 OVERLOAD mod(GF61 a) { return a; } +Z61 OVERLOAD modM61(Z61 a) { return a; } +GF61 OVERLOAD modM61(GF61 a) { return a; } Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return neg(a); } GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return neg(a); } Z61 OVERLOAD addq(Z61 a, Z61 b) { return add(a, b); } @@ -649,27 +688,27 @@ u64 OVERLOAD get_Z61(Z61 a) { Z61 m = a - M61; return (m & 0x8000000000000000ULL i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 // Internal routine to bring Z61 value into the range 0..M61+epsilon -Z61 OVERLOAD mod(Z61 a) { return (a & M61) + (a >> 61); } -GF61 OVERLOAD mod(GF61 a) { return U2(mod(a.x), mod(a.y)); } +Z61 OVERLOAD modM61(Z61 a) { return (a & M61) + (a >> 61); } +GF61 OVERLOAD modM61(GF61 a) { return U2(modM61(a.x), modM61(a.y)); } // Internal routine to negate a value by adding the specified number of M61s -- no mod M61 reduction Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return m61_count * M61 - a; } GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return U2(neg(a.x, m61_count), neg(a.y, m61_count)); } -Z61 OVERLOAD add(Z61 a, Z61 b) { return mod(a + b); } +Z61 OVERLOAD add(Z61 a, Z61 b) { return modM61(a + b); } GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } -Z61 OVERLOAD sub(Z61 a, Z61 b) { return mod(a + neg(b, 2)); } +Z61 OVERLOAD sub(Z61 a, Z61 b) { return modM61(a + neg(b, 2)); } GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } - Z61 OVERLOAD neg(Z61 a) { return mod (neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided + Z61 OVERLOAD neg(Z61 a) { return modM61(neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } // Assumes k reduced mod 61. Z61 OVERLOAD shr(Z61 a, u32 k) { return (a >> k) + ((a << (61 - k)) & M61); } // Return range 0..M61+2^(61-k), can handle 64-bit inputs but small k is big epsilon GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } Z61 OVERLOAD shl(Z61 a, u32 k) { return shr(a, 61 - k); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k yields big epsilon -//Z61 OVERLOAD shl(Z61 a, u32 k) { return mod(a << k) + ((a >> (64 - k)) << 3); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k is big epsilon -//Z61 OVERLOAD shl(Z61 a, u32 k) { return mod((a << k) + ((a >> (64 - k)) << 3)); } // Return range 0..M61+epsilon, input must be M61+epsilon a full 62-bit value can overflow +//Z61 OVERLOAD shl(Z61 a, u32 k) { return modM61(a << k) + ((a >> (64 - k)) << 3); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k is big epsilon +//Z61 OVERLOAD shl(Z61 a, u32 k) { return modM61((a << k) + ((a >> (64 - k)) << 3)); } // Return range 0..M61+epsilon, input must be M61+epsilon a full 62-bit value can overflow GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } ulong2 wideMul(u64 ab, u64 cd) { @@ -684,9 +723,9 @@ Z61 OVERLOAD weakMul(Z61 a, Z61 b) { // a*b must fit in 125 b return lo61 + hi61; } -Z61 OVERLOAD mul(Z61 a, Z61 b) { return mod(weakMul(a, b)); } +Z61 OVERLOAD mul(Z61 a, Z61 b) { return modM61(weakMul(a, b)); } -Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return mod(weakMul(a, b) + c); } // GWBUG: Can we do better? +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return modM61(weakMul(a, b) + c); } // GWBUG: Can we do better? // Multiply by 2 Z61 OVERLOAD mul2(Z61 a) { return add(a, a); } @@ -696,10 +735,10 @@ GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } // Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). -GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, mod(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } +GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, modM61(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } // a^2 + c -GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(mod(weakMul(a.x + a.y, mod(a.x + neg(a.y, 2))) + c.x), mod(weakMul(a.x + a.x, a.y) + c.y)); } +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, modM61(a.x + neg(a.y, 2))) + c.x), modM61(weakMul(a.x + a.x, a.y) + c.y)); } // Complex mul //GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} @@ -707,7 +746,7 @@ GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 Z61 k1 = weakMul(b.x, a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result Z61 k3 = weakMul(neg(a.y, 2), b.y + b.x); // 62 * 62+e bits = 63+e bit result - return U2(mod(k1 + k3), mod(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values + return U2(modM61(k1 + k3), modM61(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values } GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? @@ -743,14 +782,11 @@ void OVERLOAD X2conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, void OVERLOAD X2_mul_t4_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } // Same as X2(a, b), b = mul_t8(b) -void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = t + neg(*b, 2); *b = mul_t8(*b, 4); } +void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); t = *b + neg(t, 2); *b = shl(U2(t.x + neg(t.y, 4), t.x + t.y), 30); } // Same as X2(a, b), b = mul_3t8(b) void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = t + neg(*b, 2); *b = mul_3t8(*b, 4); } -// Same as X2(a, b), a = conjugate(a) -void OVERLOAD X2_conja_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = neg(add(a->y, b->y)); *b = sub(t, *b); } - // Same as X2(a, b), b = conjugate(b) void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } @@ -761,7 +797,7 @@ GF61 OVERLOAD foo2(GF61 a, GF61 b) { a = addsub(a); b = addsub(b); return addsub GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } // The following routines can be used to reduce mod M61 operations. Caller must track how many M61s need to be added to make positive -// values for subtractions. In function names, "q" stands for quick, "s" stands for slow (i.e. does mod). +// values for subtractions. In function names, "q" stands for quick (no modM61), "s" stands for slow (i.e. does modM61). Z61 OVERLOAD addq(Z61 a, Z61 b) { return a + b; } GF61 OVERLOAD addq(GF61 a, GF61 b) { return U2(addq(a.x, b.x), addq(a.y, b.y)); } @@ -769,11 +805,13 @@ GF61 OVERLOAD addq(GF61 a, GF61 b) { return U2(addq(a.x, b.x), addq(a.y, b.y)); Z61 OVERLOAD subq(Z61 a, Z61 b, u32 m61_count) { return a + neg(b, m61_count); } GF61 OVERLOAD subq(GF61 a, GF61 b, u32 m61_count) { return U2(subq(a.x, b.x, m61_count), subq(a.y, b.y, m61_count)); } -Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return mod(a + neg(b, m61_count)); } +Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return modM61(a + neg(b, m61_count)); } GF61 OVERLOAD subs(GF61 a, GF61 b, u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } void OVERLOAD X2q(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } -void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } +void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t = *b + neg(t, m61_count); *b = shl(U2(t.x + neg(t.y, m61_count * 2), t.x + t.y), 30); } +void OVERLOAD X2q_mul_3t8(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t = t + neg(*b, m61_count); *b = mul_3t8(t, m61_count * 2); } void OVERLOAD X2s(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, m61_count); b->y = subs(b->y, t.y, m61_count); } diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index cb60abc5..08f60886 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -606,7 +606,7 @@ void OVERLOAD onePairSq(GF31* pa, GF31* pb, GF31 t_squared) { GF31 a = *pa, b = *pb; X2conjb(a, b); - GF31 c = sub(csq(a), cmul(csq(b), t_squared)); + GF31 c = csq_sub(a, cmul(csq(b), t_squared)); // a^2 - (b^2 * t_squared) GF31 d = mul2(cmul(a, b)); X2_conjb(c, d); *pa = SWAP_XY(c), *pb = SWAP_XY(d); From 9f4ca8ca2d6dda623782191666cbaf04b8b45a52 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 18 Sep 2025 19:03:47 +0000 Subject: [PATCH 020/115] Minor tweaks and cleanup in carryutil. Some alternate GF61 math implementations that are not faster (at least on TitanV). --- src/cl/carryutil.cl | 36 ++++++++++++++---------------------- src/cl/math.cl | 19 +++++++++++++++++-- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 0a790db3..8f574791 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -50,25 +50,25 @@ void ROUNDOFF_CHECK(double x) { Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); -//GWBUG - is this ever faster? +//GWBUG - is this ever faster? Not on TitanV //i128 x128 = ((i128) (i64) i96_hi64(x) << 32) | i96_lo32(x); //i64 w = ((i64) x128 << (64 - nBits)) >> (64 - nBits); -//x128 -= w; -//*outCarry = x128 >> nBits; +//*outCarry = (i64) (x128 >> nBits) + (w < 0); //return w; // This code is tricky because me must not shift i32 or u32 variables by 32. #if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? i64 xhi = i96_hi64(x); i64 w = lowBits(xhi, nBits - 32); - xhi -= w; - *outCarry = xhi >> (nBits - 32); +// xhi -= w; //GWBUG - is (w < 0) version faster? +// *outCarry = xhi >> (nBits - 32); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); return (w << 32) | i96_lo32(x); #elif EXP / NWORDS == 32 i64 xhi = i96_hi64(x); i64 w = ((i64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); // xhi -= w >> 32; -// *outCarry = xhi >> (nBits - 32); //GWBUG - Would adding (w < 0) be faster than subtracting w>>32 from xhi? +// *outCarry = xhi >> (nBits - 32); //GWBUG - Is this ever faster than adding (w < 0)??? *outCarry = (xhi >> (nBits - 32)) + (w < 0); return w; #elif EXP / NWORDS == 31 @@ -142,7 +142,7 @@ Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { Word OVERLOAD carryStepSloppy(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); -// GWBUG Is this faster (or same speed) ???? This code doesn't work on TitanV??? +// GWBUG Is this faster (or same speed) ???? Does it work??? //i128 x128 = ((i128) xhi << 32) | i96_lo32(x); //*outCarry = x128 >> nBits; //return ((u64) x128 << (64 - nBits)) >> (64 - nBits); @@ -530,29 +530,22 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) u64 n61 = get_Z61(u61); -#if INT128_MATH -i128 v = ((i128) n61 << 31) + n31 - n61; //GWBUG - is this better/as good as int96 code? -// -// i96 value = make_i96(n61 >> 1, ((u32) n61 << 31) | n31); // (n61<<31) + n31 -// i96_sub(&value, n61); +#if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. + i128 v = ((i128) n61 << 31) + n31 - n61; // n61 * M31 + n31 // Convert to balanced representation by subtracting M61*M31 -if ((v >> 64) & 0xF8000000) v = v - (i128) M31 * (i128) M61; -// if (i96_hi32(value) & 0xF8000000) i96_sub(&value, make_i96(0x0FFFFFFF, 0xDFFFFFFF, 0x80000001)); + if ((v >> 64) & 0xF8000000) v = v - (i128) M31 * (i128) M61; // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. -// u32 roundoff = (u32) abs((i32) i96_hi32(value)); -u32 roundoff = (u32) abs((i32)(v >> 64)); + u32 roundoff = (u32) abs((i32)(v >> 64)); *maxROE = max(*maxROE, roundoff); // Mul by 3 and add carry #if MUL3 -v = v * 3; -// i96_mul(&value, 3); + v = v * 3; #endif -// i96_add(&value, make_i96((u32)(inCarry >> 63), (u64) inCarry)); -v = v + inCarry; -i96 value = make_i96((u64) (v >> 32), (u32) v); + v = v + inCarry; + i96 value = make_i96((u64) (v >> 32), (u32) v); #else @@ -571,7 +564,6 @@ i96 value = make_i96((u64) (v >> 32), (u32) v); i96_mul(&value, 3); #endif i96_add(&value, make_i96((u32)(inCarry >> 63), (u64) inCarry)); - #endif return value; diff --git a/src/cl/math.cl b/src/cl/math.cl index 203063d9..d1cb630d 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -460,7 +460,7 @@ GF31 OVERLOAD csq_sub(GF31 a, GF31 c) { } // Complex mul -#if 1 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. +#if 1 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. GF31 OVERLOAD cmul(GF31 a, GF31 b) { u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 u64 k2 = a.x * (u64) (b.y + neg(b.x)); @@ -700,7 +700,7 @@ GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } Z61 OVERLOAD sub(Z61 a, Z61 b) { return modM61(a + neg(b, 2)); } GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } - Z61 OVERLOAD neg(Z61 a) { return modM61(neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided +Z61 OVERLOAD neg(Z61 a) { return modM61(neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } // Assumes k reduced mod 61. @@ -735,19 +735,34 @@ GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } // Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +#if 1 GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, modM61(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } +#else +Z61 OVERLOAD modM61(u128 a) { return modM61(((u64) a & M61) + ((u64) (a >> 61) & M61) + (u64) (a >> 122)); } // GWBUG - Have version without second modM61??? returns a 2*M61+epsilon. +GF61 OVERLOAD csq(GF61 a) { return U2(modM61((a.x + a.y) * (u128) (a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } +#endif // a^2 + c GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, modM61(a.x + neg(a.y, 2))) + c.x), modM61(weakMul(a.x + a.x, a.y) + c.y)); } // Complex mul //GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} +#if 1 GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 Z61 k1 = weakMul(b.x, a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result Z61 k3 = weakMul(neg(a.y, 2), b.y + b.x); // 62 * 62+e bits = 63+e bit result return U2(modM61(k1 + k3), modM61(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values } +#else // Slower on TitanV +Z61 OVERLOAD modM61(u128 a) { return modM61(((u64) a & M61) + ((u64) (a >> 61) & M61) + (u64) (a >> 122)); } // GWBUG - Have version without second modM61??? returns a 2*M61+epsilon. +GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 and u128 + u128 k1 = b.x * (u128) (a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result + u128 k2 = a.x * (u128) (b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result + u128 k3 = neg(a.y, 2) * (u128) (b.y + b.x); // 62 * 62+e bits = 63+e bit result + return U2(modM61(k1 + k3), modM61(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values +} +#endif GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? From c02646c10a8ac6f5bc7ab9e2cde303f772c4034e Mon Sep 17 00:00:00 2001 From: george Date: Tue, 23 Sep 2025 19:54:40 +0000 Subject: [PATCH 021/115] Allow GPU to return integer results that are not strictly in the big word / little word range. Carryutil sloppy routines may or may not use this feature in the future. --- src/Gpu.cpp | 18 ++++++++++-------- src/state.cpp | 39 +++++++++++++++------------------------ 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index a7c0a7e1..16fb26b6 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1239,17 +1239,21 @@ u64 Gpu::bufResidue(Buffer &buf) { bufSmallOut.read(words, 64); int carry = 0; - for (int i = 0; i < 32; ++i) { carry = (words[i] + carry < 0) ? -1 : 0; } + for (int i = 0; i < 32; ++i) { + u32 len = bitlen(N, E, N - 32 + i); + i64 w = (i64) words[i] + carry; + carry = (int) (w >> len); + } u64 res = 0; int hasBits = 0; for (int k = 0; k < 32 && hasBits < 64; ++k) { u32 len = bitlen(N, E, k); - Word w = words[32 + k] + carry; - carry = (w < 0) ? -1 : 0; - if (w < 0) { w += (1LL << len); } - assert(w >= 0 && w < (1LL << len)); - res |= u64(w) << hasBits; + i64 tmp = (i64) words[32 + k] + carry; + carry = (int) (tmp >> len); + u64 w = tmp - ((i64) carry << len); + assert(w < (1ULL << len)); + res += w << hasBits; hasBits += len; } return res; @@ -1758,13 +1762,11 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool doStop = (k % blockSize == 0) && (Signal::stopRequested() || (args.iters && k - startK >= args.iters)); bool leadOut = (k % blockSize == 0) || k == persistK || k == kEnd || useLongCarry; -//if (k%10==0) leadOut = true; //GWBUG assert(!doStop || leadOut); if (doStop) { log("Stopping, please wait..\n"); } square(bufData, bufData, leadIn, leadOut, false); -//if(leadOut)printf("k %d, Residue: %lX\n", k, bufResidue(bufData)); //GWBUG leadIn = leadOut; if (k == persistK) { diff --git a/src/state.cpp b/src/state.cpp index b860aceb..6afe3485 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -10,53 +10,44 @@ static i64 lowBits(i64 u, int bits) { return (u << (64 - bits)) >> (64 - bits); } -static uWord unbalance(Word w, int nBits, int *carry) { - assert(*carry == 0 || *carry == -1); - w += *carry; - *carry = 0; - if (w < 0) { - w += ((Word) 1 << nBits); - *carry = -1; - } - if (!(0 <= w && w < ((Word) 1 << nBits))) { log("w=%lX, nBits=%d\n", (long) w, nBits); } - assert(0 <= w && w < ((Word) 1 << nBits)); - return w; -} - std::vector compactBits(const vector &dataVect, u32 E) { if (dataVect.empty()) { return {}; } // Indicating all zero - std::vector out; - out.reserve((E - 1) / 32 + 1); - u32 N = dataVect.size(); const Word *data = dataVect.data(); + std::vector out; + out.reserve((E - 1) / 32 + 1); + int carry = 0; u32 outWord = 0; int haveBits = 0; + // Convert to compact form for (u32 p = 0; p < N; ++p) { int nBits = bitlen(N, E, p); - uWord w = unbalance(data[p], nBits, &carry); - assert(nBits > 0); + + // Be careful adding in the carry -- it could overflow a 32-bit word. Convert value into desired unsigned range. + i64 tmp = (i64) data[p] + carry; + carry = (int) (tmp >> nBits); + u64 w = (u64) (tmp - ((i64) carry << nBits)); assert(w < ((uWord) 1 << nBits)); - assert(haveBits < 32); + assert(haveBits < 32); while (nBits) { - int topBits = 32 - haveBits; + int needBits = 32 - haveBits; outWord |= w << haveBits; - if (nBits >= topBits) { - w >>= topBits; - nBits -= topBits; + if (nBits >= needBits) { + w >>= needBits; + nBits -= needBits; out.push_back(outWord); outWord = 0; haveBits = 0; } else { haveBits += nBits; w >>= nBits; - break; + break; } } } From 70a0276efa196efd38323c57a9acb1c7b1d55f2f Mon Sep 17 00:00:00 2001 From: george Date: Tue, 23 Sep 2025 20:01:38 +0000 Subject: [PATCH 022/115] Wholsale re-organization of carryutil routines. New feature allows for more sloppy carries which gives a tiny performance boost for M31+M61 NTTs. --- src/cl/carryinc.cl | 56 ++--- src/cl/carryutil.cl | 537 +++++++++++++++++++++++++++----------------- src/cl/math.cl | 1 + src/common.h | 2 - 4 files changed, 365 insertions(+), 231 deletions(-) diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 0b7241f5..67d75565 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -4,11 +4,15 @@ Word2 OVERLOAD carryFinal(Word2 u, iCARRY inCarry, bool b1) { i32 tmpCarry; - u.x = carryStep(u.x + inCarry, &tmpCarry, b1); + u.x = carryStepSignedSloppy(u.x + inCarry, &tmpCarry, b1); u.y += tmpCarry; return u; } +/*******************************************************************************************/ +/* Original FP64 version to start the carry propagation process for a pair of FFT values */ +/*******************************************************************************************/ + #if FFT_FP64 & !COMBO_FFT // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. @@ -23,13 +27,13 @@ Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, bool b1, bool return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i64 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -53,13 +57,13 @@ Word2 OVERLOAD weightAndCarryPair(F2 u, F2 invWeight, iCARRY inCarry, bool b1, b return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(F2 u, F2 invWeight, iCARRY inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i32 midCarry; i32 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i32 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -83,13 +87,13 @@ Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i64 in return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -113,13 +117,13 @@ Word2 OVERLOAD weightAndCarryPair(GF61 u, u32 invWeight1, u32 invWeight2, i64 in return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(GF61 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -144,14 +148,14 @@ Word2 OVERLOAD weightAndCarryPair(T2 u, GF31 u31, T invWeight1, T invWeight2, u3 return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -176,14 +180,14 @@ Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, F invWeight1, F invWeight2, return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, i32 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i32 midCarry; i64 tmp1 = weightAndCarryOne(uF2.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i64 tmp2 = weightAndCarryOne(uF2.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -208,14 +212,14 @@ Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF61 u61, F invWeight1, F invWeight2, return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } @@ -240,18 +244,18 @@ Word2 OVERLOAD weightAndCarryPair(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m3 return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accurate calculation of the first carry is not required. +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, inCarry, maxROE); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, midCarry, maxROE); - Word b = carryStep(tmp2, outCarry, b2); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } #else -error - missing carryinc implementation +error - missing weightAndCarryPair implementation #endif diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 8f574791..a9e41039 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -13,18 +13,60 @@ typedef i32 CFcarry; // Simply use large carry always as the split kernels are slow anyway (and seldomly used normally). typedef i64 CarryABM; -#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) -i32 lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } +/********************************/ +/* Helper routines */ +/********************************/ + +// Return unsigned low bits (number of bits must be between 1 and 31) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) +u32 OVERLOAD ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } #else -i32 lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } +u32 OVERLOAD ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } #endif +u32 OVERLOAD ulowBits(u32 u, u32 bits) { return ulowBits((i32) u, bits); } +// Return unsigned low bits (number of bits must be between 1 and 63) +u64 OVERLOAD ulowBits(i64 u, u32 bits) { return (((u64) u << (64 - bits)) >> (64 - bits)); } +u64 OVERLOAD ulowBits(u64 u, u32 bits) { return ulowBits((i64) u, bits); } +// Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 32) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) -u32 ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } +u32 OVERLOAD ulowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_ubfe(u, 0, bits); } +#else +u32 OVERLOAD ulowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return u & ((1 << bits) - 1); } +#endif +u32 OVERLOAD ulowFixedBits(u32 u, const u32 bits) { return ulowFixedBits((i32) u, bits); } +// Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 64) +u64 OVERLOAD ulowFixedBits(i64 u, const u32 bits) { return u & ((1LL << bits) - 1); } +u64 OVERLOAD ulowFixedBits(u64 u, const u32 bits) { return ulowFixedBits((i64) u, bits); } + +// Return signed low bits (number of bits must be between 1 and 31) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) +i32 OVERLOAD lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } +#else +i32 OVERLOAD lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } +#endif +i32 OVERLOAD lowBits(u32 u, u32 bits) { return lowBits((i32) u, bits); } +// Return signed low bits (number of bits must be between 1 and 63) +i64 OVERLOAD lowBits(i64 u, u32 bits) { return ((u << (64 - bits)) >> (64 - bits)); } +i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64) u, bits); } + +// Return signed low bits where number of bits is known at compile time (number of bits can be 0 to 32) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_sbfe(u, 0, bits); } #else -u32 ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } +// This version should generate 2 shifts +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return (u << (32 - bits)) >> (32 - bits); } +// This version should generate 2 ANDs and one subtract +//i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; if (bits == 1) return -(u & 1); return ulowFixedBits(u, bits - 1) - (u & (1 << bits)); } +i32 OVERLOAD lowFixedBits(u32 u, const u32 bits) { return lowFixedBits((i32) u, bits); } #endif +// Return signed low bits where number of bits is known at compile time (number of bits can be 1 to 63) +//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u << (64 - bits)) >> (64 - bits)); } +//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u64) lowFixedBits((i32) ((u64) u >> 32), bits - 32) << 32) | (u32) u; } +i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } +i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64) u, bits); } +// Extract 32 bits from a 64-bit value #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, as_int2(x).x, bits); } #else @@ -34,162 +76,42 @@ i32 xtract32(i64 x, u32 bits) { return x >> bits; } u32 bitlen(bool b) { return EXP / NWORDS + b; } bool test(u32 bits, u32 pos) { return (bits >> pos) & 1; } -#if 0 -// Check for round off errors above a threshold (default is 0.43) -void ROUNDOFF_CHECK(double x) { -#if DEBUG -#ifndef ROUNDOFF_LIMIT -#define ROUNDOFF_LIMIT 0.43 -#endif - float error = fabs(x - rint(x)); - if (error > ROUNDOFF_LIMIT) printf("Roundoff: %g %30.2f\n", error, x); -#endif -} -#endif - -Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - -//GWBUG - is this ever faster? Not on TitanV -//i128 x128 = ((i128) (i64) i96_hi64(x) << 32) | i96_lo32(x); -//i64 w = ((i64) x128 << (64 - nBits)) >> (64 - nBits); -//*outCarry = (i64) (x128 >> nBits) + (w < 0); -//return w; - -// This code is tricky because me must not shift i32 or u32 variables by 32. -#if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? - i64 xhi = i96_hi64(x); - i64 w = lowBits(xhi, nBits - 32); -// xhi -= w; //GWBUG - is (w < 0) version faster? -// *outCarry = xhi >> (nBits - 32); - *outCarry = (xhi >> (nBits - 32)) + (w < 0); - return (w << 32) | i96_lo32(x); -#elif EXP / NWORDS == 32 - i64 xhi = i96_hi64(x); - i64 w = ((i64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); -// xhi -= w >> 32; -// *outCarry = xhi >> (nBits - 32); //GWBUG - Is this ever faster than adding (w < 0)??? - *outCarry = (xhi >> (nBits - 32)) + (w < 0); - return w; -#elif EXP / NWORDS == 31 - i64 w = ((i64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); - *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); - return w; -#else - i32 w = lowBits(i96_lo32(x), nBits); - *outCarry = ((i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits)) + (w < 0); - return w; -#endif -} +#if FFT_FP64 +// Rounding constant: 3 * 2^51, See https://stackoverflow.com/questions/17035464 +#define RNDVAL (3.0 * (1l << 51)) -Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); -#if EXP / NWORDS >= 33 - i32 xhi = (x >> 32); - i32 w = lowBits(xhi, nBits - 32); - xhi -= w; - *outCarry = xhi >> (nBits - 32); - return (Word) (((u64) w << 32) | (u32)(x)); -#elif EXP / NWORDS == 32 - i32 xhi = (x >> 32); - i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); - xhi -= w >> 32; - *outCarry = xhi >> (nBits - 32); - return w; -#elif EXP / NWORDS == 31 - i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); - x -= w; - *outCarry = x >> nBits; - return w; +// Convert a double to long efficiently. Double must be in RNDVAL+integer format. +i64 RNDVALdoubleToLong(double d) { + int2 words = as_int2(d); +#if EXP / NWORDS >= 19 + // We extend the range to 52 bits instead of 51 by taking the sign from the negation of bit 51 + words.y ^= 0x00080000u; + words.y = lowBits(words.y, 20); #else - Word w = lowBits((i32) x, nBits); - x -= w; - *outCarry = x >> nBits; - return w; + // Take the sign from bit 50 (i.e. use lower 51 bits). + words.y = lowBits(words.y, 19); #endif + return as_long(words); } -Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); -#if EXP / NWORDS >= 33 - i32 xhi = (x >> 32); - i32 w = lowBits(xhi, nBits - 32); - *outCarry = (xhi >> (nBits - 32)) + (w < 0); - return (Word) (((u64) w << 32) | (u32)(x)); -#elif EXP / NWORDS == 32 - i32 xhi = (x >> 32); - i64 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); - *outCarry = (i32) (xhi >> (nBits - 32)) + (w < 0); - return w; -#elif EXP / NWORDS == 31 - i32 w = (x << (64 - nBits)) >> (64 - nBits); // lowBits(x, nBits); - *outCarry = (i32) (x >> nBits) + (w < 0); - return w; -#else - Word w = lowBits(x, nBits); - *outCarry = xtract32(x, nBits) + (w < 0); - return w; -#endif -} +#elif FFT_FP32 +// Rounding constant: 3 * 2^22 +#define RNDVAL (3.0f * (1 << 22)) -Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); - *outCarry = (x - w) >> nBits; +// Convert a float to int efficiently. Float must be in RNDVAL+integer format. +i32 RNDVALfloatToInt(float d) { + int w = as_int(d); +//#if 0 +// We extend the range to 23 bits instead of 22 by taking the sign from the negation of bit 22 +// w ^= 0x00800000u; +// w = lowBits(words.y, 23); +//#else +// // Take the sign from bit 21 (i.e. use lower 22 bits). + w = lowBits(w, 22); +//#endif return w; } - -Word OVERLOAD carryStepSloppy(i96 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - -// GWBUG Is this faster (or same speed) ???? Does it work??? -//i128 x128 = ((i128) xhi << 32) | i96_lo32(x); -//*outCarry = x128 >> nBits; -//return ((u64) x128 << (64 - nBits)) >> (64 - nBits); - -// This code is tricky because me must not shift i32 or u32 variables by 32. -#if EXP / NWORDS >= 33 // nBits is 33 or more - i64 xhi = i96_hi64(x); - *outCarry = xhi >> (nBits - 32); - return (Word) (((u64) ulowBits((i32) xhi, nBits - 32) << 32) | i96_lo32(x)); -#elif EXP / NWORDS == 32 // nBits = 32 or 33 - i64 xhi = i96_hi64(x); - *outCarry = xhi >> (nBits - 32); - u64 xlo = i96_lo64(x); - return (xlo << (64 - nBits)) >> (64 - nBits); // ulowBits(xlo, nBits); -#elif EXP / NWORDS == 31 // nBits = 31 or 32 - *outCarry = (i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16)); - return ((u64) i96_lo64(x) << (64 - nBits)) >> (64 - nBits); // ulowBits(xlo, nBits); -#else // nBits less than 32 - *outCarry = (i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits); - return ulowBits(i96_lo32(x), nBits); #endif -} - -Word OVERLOAD carryStepSloppy(i64 x, i64 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - *outCarry = x >> nBits; - return ulowBits(x, nBits); -} - -Word OVERLOAD carryStepSloppy(i64 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - *outCarry = xtract32(x, nBits); - return ulowBits(x, nBits); -} - -Word OVERLOAD carryStepSloppy(i32 x, i32 *outCarry, bool isBigWord) { - u32 nBits = bitlen(isBigWord); - *outCarry = x >> nBits; - return ulowBits(x, nBits); -} - -// Carry propagation from word and carry. -Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { - a.x = carryStep(a.x + *carry, carry, b1); - a.y = carryStep(a.y + *carry, carry, b2); - return a; -} // map abs(carry) to floats, with 2^32 corresponding to 1.0 // So that the maximum CARRY32 abs(carry), 2^31, is mapped to 0.5 (the same as the maximum ROE) @@ -208,30 +130,28 @@ void updateStats(global uint *bufROE, u32 posROE, float roundMax) { } #endif - -#if FFT_FP64 - -// Rounding constant: 3 * 2^51, See https://stackoverflow.com/questions/17035464 -#define RNDVAL (3.0 * (1l << 51)) - -// Convert a double to long efficiently. Double must be in RNDVAL+integer format. -i64 RNDVALdoubleToLong(double d) { - int2 words = as_int2(d); -#if EXP / NWORDS >= 19 - // We extend the range to 52 bits instead of 51 by taking the sign from the negation of bit 51 - words.y ^= 0x00080000u; - words.y = lowBits(words.y, 20); -#else - // Take the sign from bit 50 (i.e. use lower 51 bits). - words.y = lowBits(words.y, 19); +#if 0 +// Check for round off errors above a threshold (default is 0.43) +void ROUNDOFF_CHECK(double x) { +#if DEBUG +#ifndef ROUNDOFF_LIMIT +#define ROUNDOFF_LIMIT 0.43 +#endif + float error = fabs(x - rint(x)); + if (error > ROUNDOFF_LIMIT) printf("Roundoff: %g %30.2f\n", error, x); #endif - return as_long(words); } - #endif + +/***************************************************************************/ +/* From the FFT data, construct a value to normalize and carry propagate */ +/***************************************************************************/ + #if FFT_FP64 & !COMBO_FFT +#define SLOPPY_MAXBPW 173 // Based on 142.4M expo in 7.5M FFT = 18.36 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -267,29 +187,13 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r #endif } - /**************************************************************************/ /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ #elif FFT_FP32 & !COMBO_FFT -// Rounding constant: 3 * 2^22 -#define RNDVAL (3.0f * (1 << 22)) - -// Convert a float to int efficiently. Float must be in RNDVAL+integer format. -i32 RNDVALfloatToInt(float d) { - int w = as_int(d); -//#if 0 -// We extend the range to 23 bits instead of 22 by taking the sign from the negation of bit 22 -// w ^= 0x00800000u; -// w = lowBits(words.y, 23); -//#else -// // Take the sign from bit 21 (i.e. use lower 22 bits). - w = lowBits(w, 22); -//#endif - return w; -} +#define SLOPPY_MAXBPW 0 // F32 FFTs are not practical // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -324,13 +228,14 @@ i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_r #endif } - /**************************************************************************/ /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ #elif NTT_GF31 & !COMBO_FFT +#define SLOPPY_MAXBPW 73 // Based on 140M expo in 16M FFT = 8.34 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { @@ -351,13 +256,14 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { return value + inCarry; } - /**************************************************************************/ /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ #elif NTT_GF61 & !COMBO_FFT +#define SLOPPY_MAXBPW 225 // Based on 198M expo in 8M FFT = 23.6 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { @@ -378,13 +284,14 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { return value + inCarry; } - /**************************************************************************/ /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ #elif FFT_FP64 & NTT_GF31 +#define SLOPPY_MAXBPW 327 // Based on 142M expo in 4M FFT = 33.86 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, float* maxROE) { @@ -414,13 +321,14 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, return value; } - /**************************************************************************/ /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ #elif FFT_FP32 & NTT_GF31 +#define SLOPPY_MAXBPW 154 // Based on 138M expo in 8M FFT = 16.45 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, i32 inCarry, float* maxROE) { @@ -432,9 +340,6 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, uF2 = uF2 * F2_invWeight - (float) n31; // This should be close to a multiple of M31 uF2 *= 0.0000000004656612873077392578125f; // Divide by 2^31 //GWBUG - check the generated code! -// i32 nF2 = rint(uF2); // GWBUG - does this round cheaply? Best way to round? -// Rounding constant: 3 * 2^22 -#define RNDVAL (3.0f * (1 << 22)) i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); i64 v = ((i64) nF2 << 31) - nF2; // nF2 * M31 @@ -451,13 +356,14 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, return v + inCarry; } - /**************************************************************************/ /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ #elif FFT_FP32 & NTT_GF61 +#define SLOPPY_MAXBPW 309 // Based on 134M expo in 4M FFT = 31.95 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { @@ -472,9 +378,6 @@ BUG - need more than 64 bit integers uF2 = uF2 * F2_invWeight - (float) n61; // This should be close to a multiple of M61 uF2 *= 4.3368086899420177360298112034798e-19f; // Divide by 2^61 //GWBUG - check the generated code! -// i32 nF2 = rint(uF2); // GWBUG - does this round cheaply? Best way to round? -// Rounding constant: 3 * 2^22 -#define RNDVAL (3.0f * (1 << 22)) i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); i64 v = ((i64) nF2 << 61) - nF2; // nF2 * M61 @@ -486,6 +389,7 @@ BUG - need more than 64 bit integers #else // The final result must be n61 mod M61. Use FP32 data to calculate this value. +#undef RNDVAL //GWBUG - why are we using doubles? #define RNDVAL (3.0 * (1l << 51)) double uuF2 = (double) uF2 * (double) F2_invWeight - (double) n61; // This should be close to a multiple of M61 uuF2 = uuF2 * 4.3368086899420177360298112034798e-19; // Divide by 2^61 //GWBUG - check the generated code! @@ -510,13 +414,14 @@ volatile double xxF2 = uuF2 + RNDVAL; // Divide by 2^61 return value; } - /**************************************************************************/ /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ #elif NTT_GF31 & NTT_GF61 +#define SLOPPY_MAXBPW 383 // Based on 165M expo in 4M FFT = 39.34 BPW + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, u32* maxROE) { @@ -526,12 +431,12 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 // Use chinese remainder theorem to create a 92-bit result. Loosely copied from Yves Gallot's mersenne2 program. u32 n31 = get_Z31(u31); - u61 = sub(u61, make_Z61(n31)); // u61 - u31 + u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) u64 n61 = get_Z61(u61); #if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. - i128 v = ((i128) n61 << 31) + n31 - n61; // n61 * M31 + n31 + i128 v = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 // Convert to balanced representation by subtracting M61*M31 if ((v >> 64) & 0xF8000000) v = v - (i128) M31 * (i128) M61; @@ -569,18 +474,244 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 return value; } +#else +error - missing weightAndCarryOne implementation +#endif + + +/************************************************************************/ +/* Split a value + carryIn into a big-or-little word and a carryOut */ +/************************************************************************/ + +Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + +//GWBUG - is this ever faster? Not on TitanV +//i128 x128 = ((i128) (i64) i96_hi64(x) << 32) | i96_lo32(x); +//i64 w = ((i64) x128 << (64 - nBits)) >> (64 - nBits); +//*outCarry = (i64) (x128 >> nBits) + (w < 0); +//return w; +// This code can be tricky because we must not shift i32 or u32 variables by 32. +#if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? + i64 xhi = i96_hi64(x); + i64 w = lowBits(xhi, nBits - 32); +// xhi -= w; //GWBUG - is (w < 0) version faster? +// *outCarry = xhi >> (nBits - 32); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return (w << 32) | i96_lo32(x); +#elif EXP / NWORDS == 32 + i64 xhi = i96_hi64(x); + i64 w = lowBits(i96_lo64(x), nBits); +// xhi -= w >> 32; +// *outCarry = xhi >> (nBits - 32); //GWBUG - Is this ever faster than adding (w < 0)??? + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return w; +#elif EXP / NWORDS == 31 + i64 w = lowBits(i96_lo64(x), nBits); + *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); + return w; #else -error - missing carryUtil implementation + i32 w = lowBits(i96_lo32(x), nBits); + *outCarry = ((i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits)) + (w < 0); + return w; #endif +} +Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 33 + i32 xhi = (x >> 32); + i32 whi = lowBits(xhi, nBits - 32); + *outCarry = (xhi - whi) >> (nBits - 32); + return (Word) (((u64) whi << 32) | (u32)(x)); +#elif EXP / NWORDS == 32 + i32 xhi = (x >> 32); + i64 w = lowBits(x, nBits); + xhi -= w >> 32; + *outCarry = xhi >> (nBits - 32); + return w; +#elif EXP / NWORDS == 31 + i64 w = lowBits(x, nBits); + *outCarry = (x - w) >> nBits; + return w; +#else + Word w = lowBits((i32) x, nBits); + *outCarry = (x - w) >> nBits; + return w; +#endif +} +Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 33 + i32 xhi = (x >> 32); + i32 w = lowBits(xhi, nBits - 32); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return (Word) (((u64) w << 32) | (u32)(x)); +#elif EXP / NWORDS == 32 + i32 xhi = (x >> 32); + i64 w = lowBits(x, nBits); + *outCarry = (i32) (xhi >> (nBits - 32)) + (w < 0); + return w; +#elif EXP / NWORDS == 31 + i32 w = lowBits(x, nBits); + *outCarry = (i32) (x >> nBits) + (w < 0); + return w; +#else + Word w = lowBits(x, nBits); + *outCarry = xtract32(x, nBits) + (w < 0); + return w; +#endif +} + +Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + Word w = lowBits(x, nBits); + *outCarry = (x - w) >> nBits; + return w; +} + +/*****************************************************************/ +/* Same as CarryStep but returns a faster unsigned result. */ +/* Used on first word of pair in carryFused. */ +/* CarryFinal will later turn this into a balanced signed value. */ +/*****************************************************************/ + +Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. +#if EXP / NWORDS >= 32 // nBits is 32 or more + i64 xhi = i96_hi64(x) & ~((1ULL << (bigwordBits - 32)) - 1); + *outCarry = xhi >> (nBits - 32); + return ulowBits(i96_lo64(x), bigwordBits); +#elif EXP / NWORDS == 31 // nBits = 31 or 32 + *outCarry = i96_hi64(x) << (32 - nBits); + return i96_lo32(x); // ulowBits(x, bigwordBits = 32); +#else // nBits less than 32 + u32 w = ulowBits(i96_lo32(x), bigwordBits); + *outCarry = (i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) - w) >> nBits); + return w; +#endif +} + +Word OVERLOAD carryStepUnsignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +Word OVERLOAD carryStepUnsignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = xtract32(x, nBits); + return ulowBits(x, nBits); +} + +Word OVERLOAD carryStepUnsignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +/**********************************************************************/ +/* Same as CarryStep but may return a faster big word signed result. */ +/* Used on second word of pair in carryFused when not near max BPW. */ +/* Also used on first word in carryFinal when not near max BPW. */ +/**********************************************************************/ + +Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { +#if EXP > NWORDS / 10 * SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 32 // nBits is 32 or more +// i64 w = lowFixedBits(i96_lo64(x), bigwordBits); +// i64 xhi = ((i64) i96_hi64(x) >> (bigwordBits - 32)) + (w < 0); +// *outCarry = xhi << (bigwordBits - nBits); +// or this: + u64 xlo = i96_lo64(x); + u64 xlo_topbit = xlo & (1ULL << (bigwordBits - 1)); + i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; + i64 xhi = i96_hi64(x) + (xlo_topbit >> 32); + *outCarry = xhi >> (nBits - 32); + return w; +#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 320 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) + i32 w = i96_lo32(x); // lowBits(x, bigwordBits = 32); + *outCarry = (i96_hi64(x) + (w < 0)) << (32 - nBits); + return w; +#else // nBits less than 32 //GWBUG - is there a faster version? Is this faster than plain old carryStep? + i32 w = lowBits(i96_lo32(x), bigwordBits); + *outCarry = (((i96_hi64(x) << (32 - bigwordBits)) | (i96_lo32(x) >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); + return w; +#endif +#endif +} + +Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { +#if EXP > NWORDS / 10 * SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else +// GWBUG - not timed to see if it is faster + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + u32 w = lowBits(x, bigwordBits); + *outCarry = (((x << (32 - bigwordBits)) | ((u32) x >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); + return w; +#endif +} + +Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { +#if EXP > NWORDS / 10 * SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 32 // nBits is 32 or more + u64 x_topbit = x & (1ULL << (bigwordBits - 1)); + i64 w = ulowFixedBits(x, bigwordBits - 1) - x_topbit; + i32 xhi = (i32)(x >> 32) + (i32)(x_topbit >> 32); + *outCarry = xhi >> (nBits - 32); + return w; +#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 320 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) + i32 w = x; // lowBits(x, bigwordBits = 32); + *outCarry = ((i32)(x >> 32) + (w < 0)) << (32 - nBits); + return w; +#else // nBits less than 32 //GWBUG - is there a faster version? Is this faster than plain old carryStep? No +// u32 x_topbit = (u32) x & (1 << (bigwordBits - 1)); +// i32 w = ulowFixedBits((u32) x, bigwordBits - 1) - x_topbit; +// *outCarry = (i64)(x + x_topbit) >> nBits; +// return w; + return carryStep(x, outCarry, isBigWord); +#endif +#endif +} + +Word OVERLOAD carryStepSignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { + return carryStep(x, outCarry, isBigWord); +} + + + +// Carry propagation from word and carry. Used by carryB.cl. +Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { + a.x = carryStep(a.x + *carry, carry, b1); + a.y = carryStep(a.y + *carry, carry, b2); + return a; +} /**************************************************************************/ /* Do this last, it depends on weightAndCarryOne defined above */ /**************************************************************************/ -/* Support both 32-bit and 64-bit carries */ +/* Support both 32-bit and 64-bit carries */ // GWBUG - not all NTTs need to support both carries #if WordSize <= 4 #define iCARRY i32 diff --git a/src/cl/math.cl b/src/cl/math.cl index d1cb630d..6e8d7ef1 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -25,6 +25,7 @@ u32 i96_hi32(i96 val) { return val.c.hi32; } u64 i96_lo64(i96 val) { return val.c.lo64; } u64 i96_hi64(i96 val) { return ((u64) val.a.hi32 << 32) + val.a.mid32; } u32 i96_lo32(i96 val) { return val.a.lo32; } +u32 i96_mid32(i96 val) { return val.a.mid32; } // The X2 family of macros and SWAP are #defines because OpenCL does not allow pass by reference. // With NTT support added, we need to turn these macros into overloaded routines. diff --git a/src/common.h b/src/common.h index 6daab79a..b4ff3d65 100644 --- a/src/common.h +++ b/src/common.h @@ -32,10 +32,8 @@ namespace fs = std::filesystem; // When using multiple primes in an NTT the size of an integer FFT "word" grows such that we need to support words larger than 32-bits #if (FFT_FP64 && NTT_GF31) | (FFT_FP32 && NTT_GF61) | (NTT_GF31 && NTT_GF61) typedef i64 Word; -typedef u64 uWord; // Used by unbalance #else typedef i32 Word; -typedef u32 uWord; // Used by unbalance #endif using double2 = pair; From 6f65bb783a8e7b5d0daf2a6b4d449cdf7a1748a4 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 23 Sep 2025 21:11:36 +0000 Subject: [PATCH 023/115] Undid one of the carryStepSignedSloppy changes. --- src/cl/carryutil.cl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index a9e41039..79cb6b6d 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -657,12 +657,14 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { #if EXP > NWORDS / 10 * SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); #else -// GWBUG - not timed to see if it is faster - const u32 bigwordBits = EXP / NWORDS + 1; - u32 nBits = bitlen(isBigWord); - u32 w = lowBits(x, bigwordBits); - *outCarry = (((x << (32 - bigwordBits)) | ((u32) x >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); - return w; + +// GWBUG - not timed to see if it is faster. Highly likely to be slower. +// const u32 bigwordBits = EXP / NWORDS + 1; +// u32 nBits = bitlen(isBigWord); +// u32 w = lowBits(x, bigwordBits); +// *outCarry = (((x << (32 - bigwordBits)) | ((u32) x >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); +// return w; + return carryStep(x, outCarry, isBigWord); #endif } From a71ca88cf01daf8de9e9cb33bdf492995dad5e58 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 24 Sep 2025 19:18:57 +0000 Subject: [PATCH 024/115] Added csqTrig and ccubeTrig for NTTs. Did not prove to be helpful on TitanV. Explored alternate weakMul and csq implementations. --- src/cl/fft-middle.cl | 26 ++++++++++++++++++++++---- src/cl/math.cl | 26 ++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index f4634af3..0cc68def 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -97,7 +97,7 @@ void OVERLOAD fft_MIDDLE(T2 *u) { void OVERLOAD middleMul(T2 *u, u32 s, Trig trig) { assert(s < SMALL_HEIGHT); - if (MIDDLE == 1) { return; } + if (MIDDLE == 1) return; if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. T2 w = trig[s]; // s / BIG_HEIGHT @@ -377,7 +377,7 @@ void OVERLOAD fft_MIDDLE(F2 *u) { void OVERLOAD middleMul(F2 *u, u32 s, TrigFP32 trig) { assert(s < SMALL_HEIGHT); - if (MIDDLE == 1) { return; } + if (MIDDLE == 1) return; if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. F2 w = trig[s]; // s / BIG_HEIGHT @@ -633,14 +633,23 @@ void OVERLOAD fft_MIDDLE(GF31 *u) { void OVERLOAD middleMul(GF31 *u, u32 s, TrigGF31 trig) { assert(s < SMALL_HEIGHT); - if (MIDDLE == 1) { return; } + if (MIDDLE == 1) return; if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. GF31 w = trig[s]; // s / BIG_HEIGHT WADD(1, w); + if (MIDDLE == 2) return; + +#if SHOULD_BE_FASTER + GF31 sq = csqTrig(w); + WADD(2, sq); + GF31 base = ccubeTrig(sq, w); // GWBUG: compute w^4 as csqTriq(sq), w^6 as ccubeTrig(w2, w4), and w^5 and w^7 as cmul_a_by_b_and_conjb + for (u32 k = 3; k < MIDDLE; ++k) { +#else GF31 base = csq(w); for (u32 k = 2; k < MIDDLE; ++k) { +#endif WADD(k, base); base = cmul(base, w); } @@ -737,14 +746,23 @@ void OVERLOAD fft_MIDDLE(GF61 *u) { void OVERLOAD middleMul(GF61 *u, u32 s, TrigGF61 trig) { assert(s < SMALL_HEIGHT); - if (MIDDLE == 1) { return; } + if (MIDDLE == 1) return; if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. GF61 w = trig[s]; // s / BIG_HEIGHT WADD(1, w); + if (MIDDLE == 2) return; + +#if SHOULD_BE_FASTER + GF61 sq = csqTrig(w); + WADD(2, sq); + GF61 base = ccubeTrig(sq, w); // GWBUG: compute w^4 as csqTriq(sq), w^6 as ccubeTrig(w2, w4), and w^5 and w^7 as cmul_a_by_b_and_conjb + for (u32 k = 3; k < MIDDLE; ++k) { +#else GF61 base = csq(w); for (u32 k = 2; k < MIDDLE; ++k) { +#endif WADD(k, base); base = cmul(base, w); } diff --git a/src/cl/math.cl b/src/cl/math.cl index 6e8d7ef1..2de05226 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -493,6 +493,12 @@ void OVERLOAD cmul_a_by_b_and_conjb(GF31 *res1, GF31 *res2, GF31 a, GF31 b) { res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. } +// Square a root of unity complex number +GF31 OVERLOAD csqTrig(GF31 a) { Z31 two_ay = a.y + a.y; return U2(modM31(1 + two_ay * (u64) neg(a.y)), modM31(a.x * (u64) two_ay)); } + +// Cube w, a root of unity complex number, given w^2 and w +GF31 OVERLOAD ccubeTrig(GF31 sq, GF31 w) { Z31 tmp = sq.y + sq.y; return U2(modM31(tmp * (u64) neg(w.y) + w.x), modM31(tmp * (u64) w.x + neg(w.y))); } + // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? @@ -724,6 +730,17 @@ Z61 OVERLOAD weakMul(Z61 a, Z61 b) { // a*b must fit in 125 b return lo61 + hi61; } +Z61 OVERLOAD weakMul128(Z61 a, Z61 b) { // Handles 64-bit inputs. Result is 62+e bits. + ulong2 ab = wideMul(a, b); + u64 lo = ab.x, hi = ab.y; + return (lo & M61) + ((hi << 3) & M61) + ((u32)(lo >> 61) + (u32)(hi >> 58)); +// u128 r = a * (u128) b; +// u32 rhi = r >> 122; +// u64 rmid = (r >> 61) & M61; +// u64 rlo = r & M61; +// return rhi + rmid + rlo; +} + Z61 OVERLOAD mul(Z61 a, Z61 b) { return modM61(weakMul(a, b)); } Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return modM61(weakMul(a, b) + c); } // GWBUG: Can we do better? @@ -738,6 +755,8 @@ GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } // Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). #if 1 GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, modM61(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } +#elif 1 +GF61 OVERLOAD csq(GF61 a) { return U2(modM61(weakMul128(a.x + a.y, a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } //GWBUG - This version (without modM61s) could be used to produce 62+e bit results #else Z61 OVERLOAD modM61(u128 a) { return modM61(((u64) a & M61) + ((u64) (a >> 61) & M61) + (u64) (a >> 122)); } // GWBUG - Have version without second modM61??? returns a 2*M61+epsilon. GF61 OVERLOAD csq(GF61 a) { return U2(modM61((a.x + a.y) * (u128) (a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } @@ -777,6 +796,13 @@ void OVERLOAD cmul_a_by_b_and_conjb(GF61 *res1, GF61 *res2, GF61 a, GF61 b) { res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. } +// Square a root of unity complex number (the second version may be faster if the compiler optimizes the u128 squaring). +//GF61 OVERLOAD csqTrig(GF61 a) { Z61 two_ay = a.y + a.y; return U2(modM61(1 + weakMul(two_ay, neg(a.y, 2))), mul(a.x, two_ay)); } +GF61 OVERLOAD csqTrig(GF61 a) { Z61 ay_sq = weakMul(a.y, a.y); return U2(modM61(1 + neg(ay_sq + ay_sq, 4)), mul2(weakMul(a.x, a.y))); } + +// Cube w, a root of unity complex number, given w^2 and w +GF61 OVERLOAD ccubeTrig(GF61 sq, GF61 w) { Z61 tmp = sq.y + sq.y; return U2(modM61(weakMul(tmp, neg(w.y, 2)) + w.x), modM61(weakMul(tmp, w.x) + neg(w.y, 2))); } + // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? From f9be4ba8b88e7d55e2a08ef450303ef62cda5a5c Mon Sep 17 00:00:00 2001 From: george Date: Wed, 24 Sep 2025 21:30:40 +0000 Subject: [PATCH 025/115] Eliminate chainmul in middle. It is faster on TitanV. Not sure if we need to expose the MIDDLE_CHAINMUL option to the end user. --- src/TrigBufCache.cpp | 8 +++++-- src/TrigBufCache.h | 4 ++-- src/cl/base.cl | 4 ++++ src/cl/fft-middle.cl | 50 +++++++++++++++++++++++++++++++++----------- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/TrigBufCache.cpp b/src/TrigBufCache.cpp index 734b9487..1e483b72 100644 --- a/src/TrigBufCache.cpp +++ b/src/TrigBufCache.cpp @@ -570,7 +570,9 @@ vector genMiddleTrigGF31(u32 smallH, u32 middle, u32 width) { tab.resize(1); } else { GF31 root1hm = GF31::root_one(smallH * middle); - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1hm, k)); } + for (u32 m = 1; m < middle; ++m) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1hm, k * m)); } + } GF31 root1mw = GF31::root_one(middle * width); for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF31(root1mw, k)); } GF31 root1wmh = GF31::root_one(width * middle * smallH); @@ -733,7 +735,9 @@ vector genMiddleTrigGF61(u32 smallH, u32 middle, u32 width) { tab.resize(1); } else { GF61 root1hm = GF61::root_one(smallH * middle); - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1hm, k)); } + for (u32 m = 1; m < middle; ++m) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1hm, k * m)); } + } GF61 root1mw = GF61::root_one(middle * width); for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF61(root1mw, k)); } GF61 root1wmh = GF61::root_one(width * middle * smallH); diff --git a/src/TrigBufCache.h b/src/TrigBufCache.h index 3270f558..c1a6427c 100644 --- a/src/TrigBufCache.h +++ b/src/TrigBufCache.h @@ -78,12 +78,12 @@ ulong2 root1GF61(u32 N, u32 k); // Compute the size of the largest possible trig buffer given width, middle, height (in number of uint2 values) #define SMALLTRIG_GF31_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH)) // See genSmallTrigGF31 #define SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF31 -#define MIDDLETRIG_GF31_SIZE(W,M,H) (H + W + H) // See genMiddleTrigGF31 +#define MIDDLETRIG_GF31_SIZE(W,M,H) (H * (M - 1) + W + H) // See genMiddleTrigGF31 // Compute the size of the largest possible trig buffer given width, middle, height (in number of ulong2 values) #define SMALLTRIG_GF61_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH)) // See genSmallTrigGF61 #define SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF61 -#define MIDDLETRIG_GF61_SIZE(W,M,H) (H + W + H) // See genMiddleTrigGF61 +#define MIDDLETRIG_GF61_SIZE(W,M,H) (H * (M - 1) + W + H) // See genMiddleTrigGF61 // Convert above sizes to distances (in units of double2) #define SMALLTRIG_FP64_DIST(W,M,H,nH) SMALLTRIG_FP64_SIZE(W,M,H,nH) diff --git a/src/cl/base.cl b/src/cl/base.cl index b92f8fb1..0872ef42 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -107,6 +107,10 @@ G_H "group height" == SMALL_HEIGHT / NH #define TABMUL_CHAIN 0 #endif +#if !defined(MIDDLE_CHAIN) +#define MIDDLE_CHAIN 0 +#endif + #if !defined(UNROLL_W) #if AMDGPU #define UNROLL_W 0 diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index 0cc68def..7674aeca 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -635,9 +635,16 @@ void OVERLOAD middleMul(GF31 *u, u32 s, TrigGF31 trig) { assert(s < SMALL_HEIGHT); if (MIDDLE == 1) return; - if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. - GF31 w = trig[s]; // s / BIG_HEIGHT +#if !MIDDLE_CHAIN // Read all trig values from memory + + for (u32 k = 1; k < MIDDLE; ++k) { + WADD(k, trig[s]); + s += SMALL_HEIGHT; + } + +#else + GF31 w = trig[s]; // s / BIG_HEIGHT WADD(1, w); if (MIDDLE == 2) return; @@ -653,18 +660,24 @@ void OVERLOAD middleMul(GF31 *u, u32 s, TrigGF31 trig) { WADD(k, base); base = cmul(base, w); } + +#endif + } void OVERLOAD middleMul2(GF31 *u, u32 x, u32 y, TrigGF31 trig) { assert(x < WIDTH); assert(y < SMALL_HEIGHT); - trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table - GF31 w = trig[x]; // x / (MIDDLE * WIDTH) + // First trig table comes after the MiddleMul trig table. Second trig table comes after the first MiddleMul2 trig table. + TrigGF31 trig1 = trig + SMALL_HEIGHT * (MIDDLE - 1); + TrigGF31 trig2 = trig1 + WIDTH; + // The first trig table can be shared with MiddleMul trig table if WIDTH = HEIGHT. + if (WIDTH == SMALL_HEIGHT) trig1 = trig; - TrigGF31 trig2 = trig + WIDTH; // Skip over first MiddleMul2 trig table + GF31 w = trig1[x]; // x / (MIDDLE * WIDTH) u32 desired_root = x * y; - GF31 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]); + GF31 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig1[desired_root / SMALL_HEIGHT]); WADD(0, base); for (u32 k = 1; k < MIDDLE; ++k) { @@ -748,9 +761,16 @@ void OVERLOAD middleMul(GF61 *u, u32 s, TrigGF61 trig) { assert(s < SMALL_HEIGHT); if (MIDDLE == 1) return; - if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. - GF61 w = trig[s]; // s / BIG_HEIGHT +#if !MIDDLE_CHAIN // Read all trig values from memory + + for (u32 k = 1; k < MIDDLE; ++k) { + WADD(k, trig[s]); + s += SMALL_HEIGHT; + } + +#else + GF61 w = trig[s]; // s / BIG_HEIGHT WADD(1, w); if (MIDDLE == 2) return; @@ -766,18 +786,24 @@ void OVERLOAD middleMul(GF61 *u, u32 s, TrigGF61 trig) { WADD(k, base); base = cmul(base, w); } + +#endif + } void OVERLOAD middleMul2(GF61 *u, u32 x, u32 y, TrigGF61 trig) { assert(x < WIDTH); assert(y < SMALL_HEIGHT); - trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table - GF61 w = trig[x]; // x / (MIDDLE * WIDTH) + // First trig table comes after the MiddleMul trig table. Second trig table comes after the first MiddleMul2 trig table. + TrigGF61 trig1 = trig + SMALL_HEIGHT * (MIDDLE - 1); + TrigGF61 trig2 = trig1 + WIDTH; + // The first trig table can be shared with MiddleMul trig table if WIDTH = HEIGHT. + if (WIDTH == SMALL_HEIGHT) trig1 = trig; - TrigGF61 trig2 = trig + WIDTH; // Skip over first MiddleMul2 trig table + GF61 w = trig1[x]; // x / (MIDDLE * WIDTH) u32 desired_root = x * y; - GF61 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]); + GF61 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig1[desired_root / SMALL_HEIGHT]); WADD(0, base); for (u32 k = 1; k < MIDDLE; ++k) { From ecd183290c32c1a782d998c5c5cd1959c3b3d9f8 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 25 Sep 2025 16:40:50 +0000 Subject: [PATCH 026/115] Fixed assert --- src/state.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state.cpp b/src/state.cpp index 6afe3485..4ac159d8 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -32,7 +32,7 @@ std::vector compactBits(const vector &dataVect, u32 E) { i64 tmp = (i64) data[p] + carry; carry = (int) (tmp >> nBits); u64 w = (u64) (tmp - ((i64) carry << nBits)); - assert(w < ((uWord) 1 << nBits)); + assert(w < (1ULL << nBits)); assert(haveBits < 32); while (nBits) { From 52dfa84df6897d282f22803bf5ec937edb0124b4 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 26 Sep 2025 00:43:03 +0000 Subject: [PATCH 027/115] Improved GF61 wideMul and csq. Removed unused routines. --- src/cl/math.cl | 150 +++++++++++++++---------------------------------- 1 file changed, 45 insertions(+), 105 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 2de05226..0bbf3da2 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -323,18 +323,6 @@ GF31 OVERLOAD cmul(GF31 a, GF31 b) { return U2(sub(k1, k3), add(k1, k2)); } -GF31 OVERLOAD cfma(GF31 a, GF31 b, GF31 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? - -GF31 OVERLOAD cmul_by_conjugate(GF31 a, GF31 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate - -// Multiply a by b and conjugate(b). This saves 2 multiplies. -void OVERLOAD cmul_a_by_b_and_conjb(GF31 *res1, GF31 *res2, GF31 a, GF31 b) { - Z31 axbx = mul(a.x, b.x); - Z31 aybx = mul(a.y, b.x); - res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? - res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. -} - // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? @@ -481,18 +469,6 @@ GF31 OVERLOAD cmul(GF31 a, GF31 b) { } #endif -GF31 OVERLOAD cfma(GF31 a, GF31 b, GF31 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? - -GF31 OVERLOAD cmul_by_conjugate(GF31 a, GF31 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate - -// Multiply a by b and conjugate(b). This saves 2 multiplies. -void OVERLOAD cmul_a_by_b_and_conjb(GF31 *res1, GF31 *res2, GF31 a, GF31 b) { - Z31 axbx = mul(a.x, b.x); - Z31 aybx = mul(a.y, b.x); - res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? - res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. -} - // Square a root of unity complex number GF31 OVERLOAD csqTrig(GF31 a) { Z31 two_ay = a.y + a.y; return U2(modM31(1 + two_ay * (u64) neg(a.y)), modM31(a.x * (u64) two_ay)); } @@ -612,18 +588,6 @@ GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(k1, k3), add(k1, k2)); } -GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? - -GF61 OVERLOAD cmul_by_conjugate(GF61 a, GF61 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate - -// Multiply a by b and conjugate(b). This saves 2 multiplies. -void OVERLOAD cmul_a_by_b_and_conjb(GF61 *res1, GF61 *res2, GF61 a, GF61 b) { - Z61 axbx = mul(a.x, b.x); - Z61 aybx = mul(a.y, b.x); - res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? - res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. -} - // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? @@ -698,8 +662,8 @@ i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - Z61 OVERLOAD modM61(Z61 a) { return (a & M61) + (a >> 61); } GF61 OVERLOAD modM61(GF61 a) { return U2(modM61(a.x), modM61(a.y)); } // Internal routine to negate a value by adding the specified number of M61s -- no mod M61 reduction -Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return m61_count * M61 - a; } -GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return U2(neg(a.x, m61_count), neg(a.y, m61_count)); } +Z61 OVERLOAD neg(Z61 a, const u32 m61_count) { return m61_count * M61 - a; } +GF61 OVERLOAD neg(GF61 a, const u32 m61_count) { return U2(neg(a.x, m61_count), neg(a.y, m61_count)); } Z61 OVERLOAD add(Z61 a, Z61 b) { return modM61(a + b); } GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } @@ -723,27 +687,26 @@ ulong2 wideMul(u64 ab, u64 cd) { return U2((u64) r, (u64) (r >> 64)); } -Z61 OVERLOAD weakMul(Z61 a, Z61 b) { // a*b must fit in 125 bits, result will as large as a*b >> 61 +// Returns a * b not modded by M61. Max value of result depends on the m61_counts of the inputs. +// Let n = (a_m61_count - 1) * (b_m61_count - 1). This is the maximum value in the highest 6 bits of a * b. +// If n <= 4 result will be at most (n+1)*M61+epsilon. +// If n > 4 result will be at most 2*M61+epsilon. +Z61 OVERLOAD weakMul(Z61 a, Z61 b, const u32 a_m61_count, const u32 b_m61_count) { ulong2 ab = wideMul(a, b); u64 lo = ab.x, hi = ab.y; - u64 lo61 = lo & M61, hi61 = (hi << 3) + (lo >> 61); - return lo61 + hi61; -} - -Z61 OVERLOAD weakMul128(Z61 a, Z61 b) { // Handles 64-bit inputs. Result is 62+e bits. - ulong2 ab = wideMul(a, b); - u64 lo = ab.x, hi = ab.y; - return (lo & M61) + ((hi << 3) & M61) + ((u32)(lo >> 61) + (u32)(hi >> 58)); -// u128 r = a * (u128) b; -// u32 rhi = r >> 122; -// u64 rmid = (r >> 61) & M61; -// u64 rlo = r & M61; -// return rhi + rmid + rlo; + u64 lo61 = lo & M61; // Max value is M61 + if ((a_m61_count - 1) * (b_m61_count - 1) <= 4) { + hi = (hi << 3) + (lo >> 61); // Max value is (a_m61_count - 1) * (b_m61_count - 1) * M61 + epsilon + return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 1) * M61 + epsilon + } else { + u64 hi61 = ((hi << 3) + (lo >> 61)) & M61; // Max value is M61 + return lo61 + hi61 + (hi >> 58); // Max value is 2*M61 + epsilon + } } -Z61 OVERLOAD mul(Z61 a, Z61 b) { return modM61(weakMul(a, b)); } +Z61 OVERLOAD mul(Z61 a, Z61 b) { return modM61(weakMul(a, b, 2, 2)); } -Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return modM61(weakMul(a, b) + c); } // GWBUG: Can we do better? +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return modM61(weakMul(a, b, 2, 2) + c); } // GWBUG: Can we do better? // Multiply by 2 Z61 OVERLOAD mul2(Z61 a) { return add(a, a); } @@ -753,65 +716,42 @@ GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } // Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). -#if 1 -GF61 OVERLOAD csq(GF61 a) { return U2(mul(a.x + a.y, modM61(a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } -#elif 1 -GF61 OVERLOAD csq(GF61 a) { return U2(modM61(weakMul128(a.x + a.y, a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } //GWBUG - This version (without modM61s) could be used to produce 62+e bit results -#else -Z61 OVERLOAD modM61(u128 a) { return modM61(((u64) a & M61) + ((u64) (a >> 61) & M61) + (u64) (a >> 122)); } // GWBUG - Have version without second modM61??? returns a 2*M61+epsilon. -GF61 OVERLOAD csq(GF61 a) { return U2(modM61((a.x + a.y) * (u128) (a.x + neg(a.y, 2))), mul2(weakMul(a.x, a.y))); } -#endif +GF61 OVERLOAD csqq(GF61 a, const u32 m61_count) { + if (m61_count > 4) a = modM61(a); + Z61 re = weakMul(a.x + a.y, a.x + neg(a.y, m61_count), 2 * m61_count - 1, 2 * m61_count); + Z61 im = weakMul(a.x + a.x, a.y, 2 * m61_count - 1, m61_count); + return U2(re, im); +} +GF61 OVERLOAD csqs(GF61 a, const u32 m61_count) { return modM61(csqq(a, m61_count)); } +GF61 OVERLOAD csq(GF61 a) { return csqs(a, 2); } // a^2 + c -GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, modM61(a.x + neg(a.y, 2))) + c.x), modM61(weakMul(a.x + a.x, a.y) + c.y)); } +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, a.x + neg(a.y, 2), 3, 4) + c.x), modM61(weakMul(a.x + a.x, a.y, 3, 2) + c.y)); } // Complex mul -//GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} -#if 1 -GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 - Z61 k1 = weakMul(b.x, a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result - Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result - Z61 k3 = weakMul(neg(a.y, 2), b.y + b.x); // 62 * 62+e bits = 63+e bit result - return U2(modM61(k1 + k3), modM61(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values -} -#else // Slower on TitanV -Z61 OVERLOAD modM61(u128 a) { return modM61(((u64) a & M61) + ((u64) (a >> 61) & M61) + (u64) (a >> 122)); } // GWBUG - Have version without second modM61??? returns a 2*M61+epsilon. -GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 2 extra bits in u64 and u128 - u128 k1 = b.x * (u128) (a.x + a.y); // 61+e * 62+e bits = 123+e mult = 62+e bit result - u128 k2 = a.x * (u128) (b.y + neg(b.x, 2)); // 61+e * 63+e bits = 63+e bit result - u128 k3 = neg(a.y, 2) * (u128) (b.y + b.x); // 62 * 62+e bits = 63+e bit result - return U2(modM61(k1 + k3), modM61(k1 + k2)); // k1+k3 and k1+k2 are full 64-bit values -} -#endif - -GF61 OVERLOAD cfma(GF61 a, GF61 b, GF61 c) { return add(cmul(a, b), c); } //GWBUG: Can we do better? - -GF61 OVERLOAD cmul_by_conjugate(GF61 a, GF61 b) { return cmul(a, conjugate(b)); } //GWBUG: We can likely eliminate a negate - -// Multiply a by b and conjugate(b). This saves 2 multiplies. -void OVERLOAD cmul_a_by_b_and_conjb(GF61 *res1, GF61 *res2, GF61 a, GF61 b) { - Z61 axbx = mul(a.x, b.x); - Z61 aybx = mul(a.y, b.x); - res1->x = fma(a.y, neg(b.y), axbx), res1->y = fma(a.x, b.y, aybx); //GWBUG: Can we eliminate neg? - res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, neg(b.y), aybx); //GWBUG: Can we eliminate neg? At least make it a tmp. +GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 3-epsilon extra bits in u64 + Z61 k1 = weakMul(b.x, a.x + a.y, 2, 3); // max value is 3*M61+epsilon + Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2), 2, 3); // max value is 3*M61+epsilon + Z61 k3 = weakMul(a.y, b.y + b.x, 2, 3); // max value is 3*M61+epsilon + return U2(modM61(k1 + neg(k3, 4)), modM61(k1 + k2)); } // Square a root of unity complex number (the second version may be faster if the compiler optimizes the u128 squaring). //GF61 OVERLOAD csqTrig(GF61 a) { Z61 two_ay = a.y + a.y; return U2(modM61(1 + weakMul(two_ay, neg(a.y, 2))), mul(a.x, two_ay)); } -GF61 OVERLOAD csqTrig(GF61 a) { Z61 ay_sq = weakMul(a.y, a.y); return U2(modM61(1 + neg(ay_sq + ay_sq, 4)), mul2(weakMul(a.x, a.y))); } +GF61 OVERLOAD csqTrig(GF61 a) { Z61 ay_sq = weakMul(a.y, a.y, 2, 2); return U2(modM61(1 + neg(ay_sq + ay_sq, 4)), mul2(weakMul(a.x, a.y, 2, 2))); } // Cube w, a root of unity complex number, given w^2 and w -GF61 OVERLOAD ccubeTrig(GF61 sq, GF61 w) { Z61 tmp = sq.y + sq.y; return U2(modM61(weakMul(tmp, neg(w.y, 2)) + w.x), modM61(weakMul(tmp, w.x) + neg(w.y, 2))); } +GF61 OVERLOAD ccubeTrig(GF61 sq, GF61 w) { Z61 tmp = sq.y + sq.y; return U2(modM61(weakMul(tmp, neg(w.y, 2), 3, 3) + w.x), modM61(weakMul(tmp, w.x, 3, 2) + neg(w.y, 2))); } // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? // mul with (-2^30, -2^30). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^30)^2 == 1 (mod M61). -GF61 OVERLOAD mul_t8(GF61 a, u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), neg(a.x + a.y, 2 * m61_count - 1)), 30); } +GF61 OVERLOAD mul_t8(GF61 a, const u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), neg(a.x + a.y, 2 * m61_count - 1)), 30); } GF61 OVERLOAD mul_t8(GF61 a) { return mul_t8(a, 2); } // mul with (2^30, -2^30). (twiddle of 3*tau/8). -GF61 OVERLOAD mul_3t8(GF61 a, u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } +GF61 OVERLOAD mul_3t8(GF61 a, const u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } GF61 OVERLOAD mul_3t8(GF61 a) { return mul_3t8(a, 2); } // Return a+b and a-b @@ -844,19 +784,19 @@ GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } Z61 OVERLOAD addq(Z61 a, Z61 b) { return a + b; } GF61 OVERLOAD addq(GF61 a, GF61 b) { return U2(addq(a.x, b.x), addq(a.y, b.y)); } -Z61 OVERLOAD subq(Z61 a, Z61 b, u32 m61_count) { return a + neg(b, m61_count); } -GF61 OVERLOAD subq(GF61 a, GF61 b, u32 m61_count) { return U2(subq(a.x, b.x, m61_count), subq(a.y, b.y, m61_count)); } +Z61 OVERLOAD subq(Z61 a, Z61 b, const u32 m61_count) { return a + neg(b, m61_count); } +GF61 OVERLOAD subq(GF61 a, GF61 b, const u32 m61_count) { return U2(subq(a.x, b.x, m61_count), subq(a.y, b.y, m61_count)); } -Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return modM61(a + neg(b, m61_count)); } -GF61 OVERLOAD subs(GF61 a, GF61 b, u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } +Z61 OVERLOAD subs(Z61 a, Z61 b, const u32 m61_count) { return modM61(a + neg(b, m61_count)); } +GF61 OVERLOAD subs(GF61 a, GF61 b, const u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } -void OVERLOAD X2q(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } -void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } -void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t = *b + neg(t, m61_count); *b = shl(U2(t.x + neg(t.y, m61_count * 2), t.x + t.y), 30); } -void OVERLOAD X2q_mul_3t8(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = t + *b; t = t + neg(*b, m61_count); *b = mul_3t8(t, m61_count * 2); } +void OVERLOAD X2q(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } +void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = *b + neg(t, m61_count); *b = shl(U2(t.x + neg(t.y, m61_count * 2), t.x + t.y), 30); } +void OVERLOAD X2q_mul_3t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = t + neg(*b, m61_count); *b = mul_3t8(t, m61_count * 2); } -void OVERLOAD X2s(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } -void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, m61_count); b->y = subs(b->y, t.y, m61_count); } +void OVERLOAD X2s(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, m61_count); b->y = subs(b->y, t.y, m61_count); } #endif From 4cafe0eddc3db077e7cbb4bce6838e802676a92d Mon Sep 17 00:00:00 2001 From: george Date: Fri, 26 Sep 2025 00:44:50 +0000 Subject: [PATCH 028/115] Removed some dead code. Tweaked a few comments. --- src/cl/fft-middle.cl | 42 ------------------------------------------ src/cl/fft8.cl | 4 +--- src/cl/tailsquare.cl | 8 ++++---- 3 files changed, 5 insertions(+), 49 deletions(-) diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index 7674aeca..26638cd4 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -684,27 +684,6 @@ void OVERLOAD middleMul2(GF31 *u, u32 x, u32 y, TrigGF31 trig) { base = cmul(base, w); WADD(k, base); } - -#if 0 // Might save a couple of muls with cmul_a_by_b_and_conjb if we can compute "desired_root = x * y + x * SMALL_HEIGHT" with a slightly expanded trig table - GF31 base = slowTrigGF31(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2); - WADD(1, base); - - if (MIDDLE == 2) { - WADD(0, base); - WSUB(0, w); - return; - } - - GF31 basehi, baselo; - cmul_a_by_b_and_conjb(&basehi, &baselo, base, w); - WADD(0, baselo); - WADD(2, basehi); - - for (int i = 3; i < MIDDLE; ++i) { - basehi = cmul(basehi, w); - WADD(i, basehi); - } -#endif } // Do a partial transpose during fftMiddleIn/Out @@ -810,27 +789,6 @@ void OVERLOAD middleMul2(GF61 *u, u32 x, u32 y, TrigGF61 trig) { base = cmul(base, w); WADD(k, base); } - -#if 0 // Might save a couple of muls with cmul_a_by_b_and_conjb if we can compute "desired_root = x * y + x * SMALL_HEIGHT" with a slightly expanded trig table - GF61 base = slowTrigGF61(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2); - WADD(1, base); - - if (MIDDLE == 2) { - WADD(0, base); - WSUB(0, w); - return; - } - - GF61 basehi, baselo; - cmul_a_by_b_and_conjb(&basehi, &baselo, base, w); - WADD(0, baselo); - WADD(2, basehi); - - for (int i = 3; i < MIDDLE; ++i) { - basehi = cmul(basehi, w); - WADD(i, basehi); - } -#endif } // Do a partial transpose during fftMiddleIn/Out diff --git a/src/cl/fft8.cl b/src/cl/fft8.cl index 27e7f816..56e2fc94 100644 --- a/src/cl/fft8.cl +++ b/src/cl/fft8.cl @@ -119,7 +119,6 @@ void OVERLOAD fft8Core(GF61 *u) { fft4Core(u + 4); } -// 4 MUL + 52 ADD void OVERLOAD fft8(GF61 *u) { fft8Core(u); // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo @@ -137,7 +136,7 @@ void OVERLOAD fft4CoreSpecial1(GF61 *u) { // Starts with u[0,1,2,3] havi X2s(&u[2], &u[3], 5); // u[2,3] max value before reduction is 5,6*M61+epsilon } -void OVERLOAD fft4CoreSpecial2(GF61 *u) { // Like above, u[1].y needs negation. Starts with u[0,1,2,3] having maximum values of (3,1,2,1)*M61+epsilon. +void OVERLOAD fft4CoreSpecial2(GF61 *u) { // Similar to above. Starts with u[0,1,2,3] having maximum values of (3,1,2,1)*M61+epsilon. X2q(&u[0], &u[2], 3); // u[0,2] max value is 5,6*M61+epsilon. X2q_mul_t4(&u[1], &u[3], 2); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 3,2*M61+epsilon. u[0] = modM61(u[0]); u[2] = modM61(u[2]); // Reduce the worst offenders u[0,1,2,3] have maximum values of (1,3,1,2)*M61+epsilon. @@ -154,7 +153,6 @@ void OVERLOAD fft8Core(GF61 *u) { // Starts with all u[i] having fft4CoreSpecial2(u + 4); } -// 4 MUL + 52 ADD void OVERLOAD fft8(GF61 *u) { fft8Core(u); // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 08f60886..b9c81fab 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -901,9 +901,9 @@ void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared) { GF61 a = *pa, b = *pb; X2conjb(a, b); - GF61 c = subq(csq(a), cmul(csq(b), t_squared), 2); - GF61 d = 2 * cmul(a, b); - X2s_conjb(&c, &d, 4); + GF61 c = subq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon + GF61 d = 2 * cmul(a, b); // max d value is 2*M61+epsilon + X2s_conjb(&c, &d, 3); *pa = SWAP_XY(c), *pb = SWAP_XY(d); } @@ -1087,7 +1087,7 @@ void OVERLOAD pairSq2_special(GF61 *u, GF61 base_squared) { u[0] = SWAP_XY(mul2(foo(u[0]))); u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); } else { - onePairSq(&u[i], &u[NH/2+i], base_squared); //GWBUG - why are we only using neg(base squareds) onePairSq could easily compensate for this + onePairSq(&u[i], &u[NH/2+i], base_squared); } GF61 new_base_squared = mul_t4(base_squared); onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); From a0d29d7328f2481b1ff5526ebdf302db49ad92ed Mon Sep 17 00:00:00 2001 From: george Date: Fri, 26 Sep 2025 14:21:40 +0000 Subject: [PATCH 029/115] Fixed TailMul bug with single_wide kernels. The #define for SINGLE_WIDE was not set. --- src/cl/tailsquare.cl | 21 +-------------------- src/cl/tailutil.cl | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index b9c81fab..30dce227 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -4,25 +4,6 @@ #include "trig.cl" #include "fftheight.cl" -// TAIL_TRIGS setting: -// 2 = No memory accesses, trig values computed from scratch. Good for excellent DP GPUs such as Titan V or Radeon VII Pro. -// 1 = Limited memory accesses and some DP computation. Tuned for Radeon VII a GPU with good DP performance. -// 0 = No DP computation. Trig vaules read from memory. Good for GPUs with poor DP performance (a typical consumer grade GPU). -#if !defined(TAIL_TRIGS) -#define TAIL_TRIGS 2 // Default is compute trig values from scratch -#endif - -// TAIL_KERNELS setting: -// 0 = single wide, single kernel -// 1 = single wide, two kernels -// 2 = double wide, single kernel -// 3 = double wide, two kernels -#if !defined(TAIL_KERNELS) -#define TAIL_KERNELS 2 // Default is double-wide tailSquare with two kernels -#endif -#define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare -#define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels - #if FFT_FP64 // Handle the final squaring step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. @@ -903,7 +884,7 @@ void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared) { X2conjb(a, b); GF61 c = subq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon GF61 d = 2 * cmul(a, b); // max d value is 2*M61+epsilon - X2s_conjb(&c, &d, 3); + X2s_conjb(&c, &d, 4); *pa = SWAP_XY(c), *pb = SWAP_XY(d); } diff --git a/src/cl/tailutil.cl b/src/cl/tailutil.cl index 8d6d6567..dd1cf41f 100644 --- a/src/cl/tailutil.cl +++ b/src/cl/tailutil.cl @@ -2,6 +2,25 @@ #include "math.cl" +// TAIL_TRIGS setting: +// 2 = No memory accesses, trig values computed from scratch. Good for excellent DP GPUs such as Titan V or Radeon VII Pro. +// 1 = Limited memory accesses and some DP computation. Tuned for Radeon VII a GPU with good DP performance. +// 0 = No DP computation. Trig vaules read from memory. Good for GPUs with poor DP performance (a typical consumer grade GPU). +#if !defined(TAIL_TRIGS) +#define TAIL_TRIGS 2 // Default is compute trig values from scratch +#endif + +// TAIL_KERNELS setting: +// 0 = single wide, single kernel +// 1 = single wide, two kernels +// 2 = double wide, single kernel +// 3 = double wide, two kernels +#if !defined(TAIL_KERNELS) +#define TAIL_KERNELS 2 // Default is double-wide tailSquare with two kernels +#endif +#define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare +#define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels + #if FFT_FP64 void OVERLOAD reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { From e64ee76745290b61be12e28858fb12845df4e0b0 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 26 Sep 2025 19:36:34 +0000 Subject: [PATCH 030/115] Delayed negations and mul_t4 from pairSq into onePairSq. --- src/cl/math.cl | 26 ++++++++++++++++++-- src/cl/tailsquare.cl | 56 +++++++++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 0bbf3da2..c25a001a 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -435,10 +435,10 @@ GF31 OVERLOAD csq(GF31 a) { } // a^2 + c -GF31 OVERLOAD csqa(GF31 a, GF31 c) { +GF31 OVERLOAD csq_add(GF31 a, GF31 c) { u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 - return U2(modM31(r + c.x), modM31(i + c.y)); + return U2(modM31(r + c.x), modM31(i + c.y)); // GWBUG - hopefully the 64-bit adds are "free" via MAD instructions } // a^2 - c @@ -448,6 +448,20 @@ GF31 OVERLOAD csq_sub(GF31 a, GF31 c) { return U2(modM31(r + neg(c.x)), modM31((i64) i - c.y)); // GWBUG - check that the compiler generates MAD instructions } +// a^2 + i*c +GF31 OVERLOAD csq_addi(GF31 a, GF31 c) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r + neg(c.y)), modM31(i + c.x)); // GWBUG - hopefully the 64-bit adds are "free" via MAD instructions +} + +// a^2 - i*c +GF31 OVERLOAD csq_subi(GF31 a, GF31 c) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r + c.y), modM31((i64) i - c.x)); // GWBUG - check that the compiler generates MAD instructions +} + // Complex mul #if 1 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. GF31 OVERLOAD cmul(GF31 a, GF31 b) { @@ -743,6 +757,11 @@ GF61 OVERLOAD csqTrig(GF61 a) { Z61 ay_sq = weakMul(a.y, a.y, 2, 2); return U2(m // Cube w, a root of unity complex number, given w^2 and w GF61 OVERLOAD ccubeTrig(GF61 sq, GF61 w) { Z61 tmp = sq.y + sq.y; return U2(modM61(weakMul(tmp, neg(w.y, 2), 3, 3) + w.x), modM61(weakMul(tmp, w.x, 3, 2) + neg(w.y, 2))); } +// a + i*b +GF61 OVERLOAD addi(GF61 a, GF61 b) { return U2(sub(a.x, b.y), add(a.y, b.x)); } +// a - i*b +GF61 OVERLOAD subi(GF61 a, GF61 b) { return U2(add(a.x, b.y), sub(a.y, b.x)); } + // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? @@ -790,6 +809,9 @@ GF61 OVERLOAD subq(GF61 a, GF61 b, const u32 m61_count) { return U2(subq(a.x, b. Z61 OVERLOAD subs(Z61 a, Z61 b, const u32 m61_count) { return modM61(a + neg(b, m61_count)); } GF61 OVERLOAD subs(GF61 a, GF61 b, const u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } +GF61 OVERLOAD addiq(GF61 a, GF61 b, const u32 m61_count) { return U2(subq(a.x, b.y, m61_count), addq(a.y, b.x)); } +GF61 OVERLOAD subiq(GF61 a, GF61 b, const u32 m61_count) { return U2(addq(a.x, b.y), subq(a.y, b.x, m61_count)); } + void OVERLOAD X2q(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = *b + neg(t, m61_count); *b = shl(U2(t.x + neg(t.y, m61_count * 2), t.x + t.y), 30); } diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 30dce227..90a37203 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -583,12 +583,20 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #if NTT_GF31 -void OVERLOAD onePairSq(GF31* pa, GF31* pb, GF31 t_squared) { +void OVERLOAD onePairSq(GF31* pa, GF31* pb, GF31 t_squared, const u32 t_squared_type) { GF31 a = *pa, b = *pb; + GF31 c, d; X2conjb(a, b); - GF31 c = csq_sub(a, cmul(csq(b), t_squared)); // a^2 - (b^2 * t_squared) - GF31 d = mul2(cmul(a, b)); + if (t_squared_type == 0) // mul t_squared by 1 + c = csq_sub(a, cmul(csq(b), t_squared)); // a^2 - (b^2 * t_squared) + if (t_squared_type == 1) // mul t_squared by i + c = csq_subi(a, cmul(csq(b), t_squared)); // a^2 - i*(b^2 * t_squared) + if (t_squared_type == 2) // mul t_squared by -1 + c = csq_add(a, cmul(csq(b), t_squared)); // a^2 - -1*(b^2 * t_squared) + if (t_squared_type == 3) // mul t_squared by -i + c = csq_addi(a, cmul(csq(b), t_squared)); // a^2 - -i*(b^2 * t_squared) + d = mul2(cmul(a, b)); X2_conjb(c, d); *pa = SWAP_XY(c), *pb = SWAP_XY(d); } @@ -601,18 +609,17 @@ void OVERLOAD pairSq(u32 N, GF31 *u, GF31 *v, GF31 base_squared, bool special) { u[i] = SWAP_XY(mul2(foo(u[i]))); v[i] = SWAP_XY(shl(csq(v[i]), 2)); } else { - onePairSq(&u[i], &v[i], base_squared); + onePairSq(&u[i], &v[i], base_squared, 0); } if (N == NH) { - onePairSq(&u[i+NH/2], &v[i+NH/2], neg(base_squared)); //GWBUG -- can we write a special onepairsq that expects a base_squared that needs negation? + onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared, 2); } - GF31 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); //GWBUG -- or another special onePairSq that expects mul_t4'ed base_squared + onePairSq(&u[i+NH/4], &v[i+NH/4], base_squared, 1); if (N == NH) { - onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], neg(new_base_squared)); //GWBUG -- or another special onePairSq that expects mul_t4'ed and negated base_squared + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], base_squared, 3); } } } @@ -773,10 +780,9 @@ void OVERLOAD pairSq2_special(GF31 *u, GF31 base_squared) { u[0] = SWAP_XY(mul2(foo(u[0]))); u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); } else { - onePairSq(&u[i], &u[NH/2+i], base_squared); //GWBUG - why are we only using neg(base squareds) onePairSq could easily compensate for this + onePairSq(&u[i], &u[NH/2+i], base_squared, 0); } - GF31 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], base_squared, 1); } } @@ -878,12 +884,20 @@ KERNEL(G_H * 2) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { #if NTT_GF61 -void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared) { +void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared, const u32 t_squared_type) { GF61 a = *pa, b = *pb; + GF61 c, d; X2conjb(a, b); - GF61 c = subq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon - GF61 d = 2 * cmul(a, b); // max d value is 2*M61+epsilon + if (t_squared_type == 0) // mul t_squared by 1 + c = subq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon + if (t_squared_type == 1) // mul t_squared by i + c = subiq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon + if (t_squared_type == 2) // mul t_squared by -1 + c = addq(csq(a), cmul(csq(b), t_squared)); // max c value is 3*M61+epsilon + if (t_squared_type == 3) // mul t_squared by -i + c = addiq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon + d = 2 * cmul(a, b); // max d value is 2*M61+epsilon X2s_conjb(&c, &d, 4); *pa = SWAP_XY(c), *pb = SWAP_XY(d); } @@ -896,18 +910,17 @@ void OVERLOAD pairSq(u32 N, GF61 *u, GF61 *v, GF61 base_squared, bool special) { u[i] = SWAP_XY(mul2(foo(u[i]))); v[i] = SWAP_XY(shl(csq(v[i]), 2)); } else { - onePairSq(&u[i], &v[i], base_squared); + onePairSq(&u[i], &v[i], base_squared, 0); } if (N == NH) { - onePairSq(&u[i+NH/2], &v[i+NH/2], neg(base_squared)); //GWBUG -- can we write a special onepairsq that expects a base_squared that needs negation? + onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared, 2); } - GF61 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); //GWBUG -- or another special onePairSq that expects mul_t4'ed base_squared + onePairSq(&u[i+NH/4], &v[i+NH/4], base_squared, 1); if (N == NH) { - onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], neg(new_base_squared)); //GWBUG -- or another special onePairSq that expects mul_t4'ed and negated base_squared + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], base_squared, 3); } } } @@ -1068,10 +1081,9 @@ void OVERLOAD pairSq2_special(GF61 *u, GF61 base_squared) { u[0] = SWAP_XY(mul2(foo(u[0]))); u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); } else { - onePairSq(&u[i], &u[NH/2+i], base_squared); + onePairSq(&u[i], &u[NH/2+i], base_squared, 0); } - GF61 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], base_squared, 1); } } From 101e160188f36defbb45fc412ab65821fed5139e Mon Sep 17 00:00:00 2001 From: george Date: Fri, 26 Sep 2025 20:22:01 +0000 Subject: [PATCH 031/115] Use csqq to save one modular GF61 reduction in onePairSquare --- src/cl/math.cl | 2 +- src/cl/tailmul.cl | 6 +++--- src/cl/tailsquare.cl | 20 ++++++++++---------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index c25a001a..18951301 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -818,7 +818,7 @@ void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; * void OVERLOAD X2q_mul_3t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = t + neg(*b, m61_count); *b = mul_3t8(t, m61_count * 2); } void OVERLOAD X2s(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } -void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, m61_count); b->y = subs(b->y, t.y, m61_count); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, const u32 a_m61_count, const u32 b_m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, b_m61_count); b->y = subs(b->y, t.y, a_m61_count); } #endif diff --git a/src/cl/tailmul.cl b/src/cl/tailmul.cl index d5421b10..b0e234be 100644 --- a/src/cl/tailmul.cl +++ b/src/cl/tailmul.cl @@ -376,9 +376,9 @@ void OVERLOAD onePairMul(GF61* pa, GF61* pb, GF61* pc, GF61* pd, GF61 t_squared) X2conjb(a, b); X2conjb(c, d); - GF61 e = subq(cmul(a, c), cmul(cmul(b, d), t_squared), 2); - GF61 f = addq(cmul(b, c), cmul(a, d)); - X2s_conjb(&e, &f, 4); + GF61 e = subq(cmul(a, c), cmul(cmul(b, d), t_squared), 2); // Max value is 3*M61+epsilon + GF61 f = addq(cmul(b, c), cmul(a, d)); // Max value is 2*M61+epsilon + X2s_conjb(&e, &f, 4, 3); *pa = SWAP_XY(e), *pb = SWAP_XY(f); } diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 90a37203..db88d203 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -889,16 +889,16 @@ void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared, const u32 t_squared_ GF61 c, d; X2conjb(a, b); - if (t_squared_type == 0) // mul t_squared by 1 - c = subq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon - if (t_squared_type == 1) // mul t_squared by i - c = subiq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon - if (t_squared_type == 2) // mul t_squared by -1 - c = addq(csq(a), cmul(csq(b), t_squared)); // max c value is 3*M61+epsilon - if (t_squared_type == 3) // mul t_squared by -i - c = addiq(csq(a), cmul(csq(b), t_squared), 2); // max c value is 3*M61+epsilon - d = 2 * cmul(a, b); // max d value is 2*M61+epsilon - X2s_conjb(&c, &d, 4); + if (t_squared_type == 0) // mul t_squared by 1 + c = subq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + if (t_squared_type == 1) // mul t_squared by i + c = subiq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + if (t_squared_type == 2) // mul t_squared by -1 + c = addq(csqq(a, 2), cmul(csq(b), t_squared)); // max c value is 3*M61+epsilon + if (t_squared_type == 3) // mul t_squared by -i + c = addiq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + d = 2 * cmul(a, b); // max d value is 2*M61+epsilon + X2s_conjb(&c, &d, 5, 3); *pa = SWAP_XY(c), *pb = SWAP_XY(d); } From f3c334f469df64d7be7e66f1d1fada19026894b2 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 30 Sep 2025 19:05:00 +0000 Subject: [PATCH 032/115] Fixed optimization bug in GF31*GF61 NTT where BPW was low (less than 23) --- src/cl/carryutil.cl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 79cb6b6d..bbe291f2 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -682,7 +682,10 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { i32 xhi = (i32)(x >> 32) + (i32)(x_topbit >> 32); *outCarry = xhi >> (nBits - 32); return w; -#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 320 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) +// nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance). For reasons I don't fully understand the sloppy +// case fails if BPW is too low. Probably something to do with a small BPW with sloppy 32-bit values would require CARRY_LONG to work properly. +// Not a major concern as end users should avoid small BPW as there is probably a more efficient NTT that could be used. +#elif EXP / NWORDS == 31 || (EXP / NWORDS >= 23 && SLOPPY_MAXBPW >= 320) i32 w = x; // lowBits(x, bigwordBits = 32); *outCarry = ((i32)(x >> 32) + (w < 0)) << (32 - nBits); return w; From acbd0a2b1cc2de2b133e0859473761e23b32caba Mon Sep 17 00:00:00 2001 From: george Date: Fri, 3 Oct 2025 21:07:20 +0000 Subject: [PATCH 033/115] Fixed typo bug in FP32 code --- src/cl/ffthin.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/ffthin.cl b/src/cl/ffthin.cl index b579132b..cc519f96 100644 --- a/src/cl/ffthin.cl +++ b/src/cl/ffthin.cl @@ -43,7 +43,7 @@ KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { CP(F2) inF2 = (CP(F2)) in; P(F2) outF2 = (P(F2)) out; - TrigFP32 trigF2 = (TrigFP32) smallTrig; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; F2 u[NH]; u32 g = get_group_id(0); From 58e18e4e4dc1a4a423c4d8f46adecd49f81fdde6 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 4 Oct 2025 00:25:14 +0000 Subject: [PATCH 034/115] Fixed next bug in FP32 code. --- src/cl/ffthin.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/ffthin.cl b/src/cl/ffthin.cl index cc519f96..ae14ec53 100644 --- a/src/cl/ffthin.cl +++ b/src/cl/ffthin.cl @@ -57,7 +57,7 @@ KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { F2 w = slowTrig_N(ND / SMALL_HEIGHT * me, ND / NH); #endif - fft_HEIGHT(lds, u, smallTrigF2, w); + fft_HEIGHT(lds, u, smallTrigF2); write(G_H, NH, u, outF2, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); } From 3ccb6bf873cff9d77827eb5113d5b529c232ef46 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 4 Oct 2025 02:06:04 +0000 Subject: [PATCH 035/115] Make all FFTs and NTTs accessible from FFT specs in a single executable! Tune.txt enhanced to pick either FP64 FFTs or integer NTTs. --- src/Args.cpp | 61 ++--- src/Args.h | 3 +- src/FFTConfig.cpp | 127 +++++---- src/FFTConfig.h | 34 ++- src/Gpu.cpp | 593 ++++++++++++++++++++----------------------- src/Gpu.h | 45 ++-- src/Task.cpp | 47 ++-- src/Task.h | 7 +- src/TrigBufCache.cpp | 505 ++++++++++++++++++------------------ src/TrigBufCache.h | 24 +- src/cl/base.cl | 49 ++-- src/cl/fftbase.cl | 12 +- src/cl/tailmul.cl | 4 +- src/cl/tailsquare.cl | 40 +-- src/cl/tailutil.cl | 11 +- src/common.h | 13 +- src/fftbpw.h | 98 +++++++ src/tune.cpp | 589 +++++++++++++++++++++++++++++++----------- src/tune.h | 4 +- 19 files changed, 1308 insertions(+), 958 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 17aa6738..10201957 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -196,21 +196,6 @@ named "config.txt" in the prpll run directory. -tune : measures the speed of the FFTs specified in -fft to find the best FFT for each exponent. --ctune : finds the best configuration for each FFT specified in -fft . - Prints the results in a form that can be incorporated in config.txt - -fft 6.5M -ctune "OUT_SIZEX=32,8;OUT_WG=64,128,256" - - It is possible to specify -ctune multiple times on the same command in order to define multiple - sets of parameters to be combined, e.g.: - -ctune "IN_WG=256,128,64" -ctune "OUT_WG=256,64;OUT_SIZEX=32,16,8" - which would try only 8 combinations among those two sets. - - The tunable parameters (with the default value emphasized) are: - IN_WG, OUT_WG: 64, 128, *256* - IN_SIZEX, OUT_SIZEX: 4, 8, 16, *32* - UNROLL_W: *0*, 1 - UNROLL_H: 0, 1 - -device : select the GPU at position N in the list of devices -uid : select the GPU with the given UID (on ROCm/AMDGPU, Linux) -pci : select the GPU with the given PCI BDF, e.g. "0c:00.0" @@ -236,31 +221,34 @@ Device selection : use one of -uid , -pci , -device , see the list ); } - printf("\nFFT Configurations (specify with -fft :: from the set below):\n" + printf("\nFFT Configurations (specify with -fft ::: from the set below):\n" " Size MaxExp BPW FFT\n"); - + vector configs = FFTShape::allShapes(); configs.push_back(configs.front()); // dummy guard for the loop below. - string variants; u32 activeSize = 0; - double maxBpw = 0; - for (auto c : configs) { - if (c.size() != activeSize) { - if (!variants.empty()) { - printf("%5s %7.2fM %.2f %s\n", - numberK(activeSize).c_str(), - // activeSize * FFTShape::MIN_BPW / 1'000'000, - activeSize * maxBpw / 1'000'000.0, - maxBpw, - variants.c_str()); - variants.clear(); + float maxBpw = 0; + string variants; + for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61}) { + for (auto c : configs) { + if (c.fft_type != type) continue; + if (c.size() != activeSize) { + if (!variants.empty()) { + printf("%5s %7.2fM %.2f %s\n", + numberK(activeSize).c_str(), + // activeSize * FFTShape::MIN_BPW / 1'000'000, + activeSize * maxBpw / 1'000'000.0, + maxBpw, + variants.c_str()); + variants.clear(); + } + activeSize = c.size(); + maxBpw = 0; } - activeSize = c.size(); - maxBpw = 0; + maxBpw = max(maxBpw, c.maxBpw()); + if (!variants.empty()) { variants.push_back(','); } + variants += c.spec(); } - maxBpw = max(maxBpw, c.maxBpw()); - if (!variants.empty()) { variants.push_back(','); } - variants += c.spec(); } } @@ -295,9 +283,10 @@ void Args::parse(const string& line) { log("-info expects an FFT spec, e.g. -info 1K:13:256\n"); throw "-info "; } - log(" FFT | BPW | Max exp (M)\n"); + log(" FFT | BPW | Max exp (M)\n"); for (const FFTShape& shape : FFTShape::multiSpec(s)) { for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { + if (variant != LAST_VARIANT && shape.fft_type != FFT64) continue; FFTConfig fft{shape, variant, CARRY_AUTO}; log("%12s | %.2f | %5.1f\n", fft.spec().c_str(), fft.maxBpw(), fft.maxExp() / 1'000'000.0); } @@ -310,8 +299,8 @@ void Args::parse(const string& line) { assert(s.empty()); logROE = true; } else if (key == "-tune") { - assert(s.empty()); doTune = true; + if (!s.empty()) { tune = s; } } else if (key == "-ctune") { doCtune = true; if (!s.empty()) { ctune.push_back(s); } diff --git a/src/Args.h b/src/Args.h index c7e92259..64823d7b 100644 --- a/src/Args.h +++ b/src/Args.h @@ -43,6 +43,7 @@ class Args { string uid; string verifyPath; + string tune; vector ctune; bool doCtune{}; @@ -53,7 +54,7 @@ class Args { std::map flags; std::map> perFftConfig; - + int device = 0; bool safeMath = true; diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 669f7eea..09c945ae 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -19,10 +19,10 @@ using namespace std; struct FftBpw { string fft; - array bpw; + array bpw; }; -map> BPW { +map> BPW { #include "fftbpw.h" }; @@ -39,6 +39,7 @@ u32 parseInt(const string& s) { } // namespace // Accepts: +// - a prefix indicating FFT_type (if not specified, default is FP64) // - a single config: 1K:13:256 // - a size: 6.5M // - a range of sizes: 6.5M-7M @@ -49,13 +50,19 @@ vector FFTShape::multiSpec(const string& iniSpec) { vector ret; for (const string &spec : split(iniSpec, ',')) { + enum FFT_TYPES fft_type = FFT64; auto parts = split(spec, ':'); + if (parseInt(parts[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(parts[0]); + for (u32 i = 1; i < parts.size(); ++i) parts[i-1] = parts[i]; + parts.resize(parts.size() - 1); + } assert(parts.size() <= 3); if (parts.size() == 3) { u32 width = parseInt(parts[0]); u32 middle = parseInt(parts[1]); u32 height = parseInt(parts[2]); - ret.push_back({width, middle, height}); + ret.push_back({fft_type, width, middle, height}); continue; } assert(parts.size() == 1); @@ -76,13 +83,16 @@ vector FFTShape::multiSpec(const string& iniSpec) { vector FFTShape::allShapes(u32 sizeFrom, u32 sizeTo) { vector configs; - for (u32 width : {256, 512, 1024, 4096}) { - for (u32 height : {256, 512, 1024}) { - if (width == 256 && height == 1024) { continue; } // Skip because we prefer width >= height - for (u32 middle : {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { - u32 sz = width * height * middle * 2; - if (sizeFrom <= sz && sz <= sizeTo) { - configs.push_back({width, middle, height}); + for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61}) { + for (u32 width : {256, 512, 1024, 4096}) { + for (u32 height : {256, 512, 1024}) { + if (width == 256 && height == 1024) { continue; } // Skip because we prefer width >= height + for (u32 middle : {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + if (type != FFT64 && (middle & (middle - 1))) continue; // Reject non-power-of-two NTTs + u32 sz = width * height * middle * 2; + if (sizeFrom <= sz && sz <= sizeTo) { + configs.push_back({type, width, middle, height}); + } } } } @@ -101,36 +111,27 @@ vector FFTShape::allShapes(u32 sizeFrom, u32 sizeTo) { FFTShape::FFTShape(const string& spec) { assert(!spec.empty()); + enum FFT_TYPES fft_type = FFT64; vector v = split(spec, ':'); + if (parseInt(v[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(v[0]); + for (u32 i = 1; i < v.size(); ++i) v[i-1] = v[i]; + v.resize(v.size() - 1); + } assert(v.size() == 3); - *this = FFTShape{v.at(0), v.at(1), v.at(2)}; + *this = FFTShape{fft_type, v.at(0), v.at(1), v.at(2)}; } -FFTShape::FFTShape(const string& w, const string& m, const string& h) : - FFTShape{parseInt(w), parseInt(m), parseInt(h)} +FFTShape::FFTShape(enum FFT_TYPES t, const string& w, const string& m, const string& h) : + FFTShape{t, parseInt(w), parseInt(m), parseInt(h)} {} -double FFTShape::carry32BPW() const { - // The formula below was validated empirically with -carryTune - - // We observe that FFT 6.5M (1024:13:256) has safe carry32 up to 18.35 BPW - // while the 0.5*log2() models the impact of FFT size changes. - // We model carry with a Gumbel distrib similar to the one used for ROE, and measure carry with - // -use STATS=1. See -carryTune - -//GW: I have no idea why this is needed. Without it, -tune fails on FFT sizes from 256K to 1M -// Perhaps it has something to do with RNDVALdoubleToLong in carryutil -if (18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())) > 19.0) return 19.0; - - return 18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())); -} - -bool FFTShape::needsLargeCarry(u32 E) const { - return E / double(size()) > carry32BPW(); +FFTShape::FFTShape(u32 w, u32 m, u32 h) : + FFTShape(FFT64, w, m, h) { } -FFTShape::FFTShape(u32 w, u32 m, u32 h) : - width{w}, middle{m}, height{h} { +FFTShape::FFTShape(enum FFT_TYPES t, u32 w, u32 m, u32 h) : + fft_type{t}, width{w}, middle{m}, height{h} { assert(w && m && h); // Un-initialized shape, don't set BPW @@ -158,13 +159,32 @@ FFTShape::FFTShape(u32 w, u32 m, u32 h) : for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) bpw[j] -= 0.05; // Assume this fft spec is worse than measured fft specs if (this->isFavoredShape()) { // Don't output this warning message for non-favored shapes (we expect the BPW info to be missing) printf("BPW info for %s not found, defaults={", s.c_str()); - for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", bpw[j]); - printf("}\n"); + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", (double) bpw[j]); + printf("}\n"); } } } } +float FFTShape::carry32BPW() const { + // The formula below was validated empirically with -carryTune + + // We observe that FFT 6.5M (1024:13:256) has safe carry32 up to 18.35 BPW + // while the 0.5*log2() models the impact of FFT size changes. + // We model carry with a Gumbel distrib similar to the one used for ROE, and measure carry with + // -use STATS=1. See -carryTune + +//GW: I have no idea why this is needed. Without it, -tune fails on FFT sizes from 256K to 1M +// Perhaps it has something to do with RNDVALdoubleToLong in carryutil +if (18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())) > 19.0) return 19.0; + + return 18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())); +} + +bool FFTShape::needsLargeCarry(u32 E) const { + return E / double(size()) > carry32BPW(); +} + // Return TRUE for "favored" shapes. That is, those that are most likely to be useful. To save time in generating bpw data, only these favored // shapes have their bpw data pre-computed. Bpw for non-favored shapes is guessed from the bpw data we do have. Also. -tune will normally only // time favored shapes. These are the rules for deciding favored shapes: @@ -183,18 +203,24 @@ bool FFTShape::isFavoredShape() const { FFTConfig::FFTConfig(const string& spec) { auto v = split(spec, ':'); - // assert(v.size() == 1 || v.size() == 3 || v.size() == 4 || v.size() == 5); + + enum FFT_TYPES fft_type = FFT64; + if (parseInt(v[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(v[0]); + for (u32 i = 1; i < v.size(); ++i) v[i-1] = v[i]; + v.resize(v.size() - 1); + } if (v.size() == 1) { *this = {FFTShape::multiSpec(spec).front(), LAST_VARIANT, CARRY_AUTO}; } if (v.size() == 3) { - *this = {FFTShape{v[0], v[1], v[2]}, LAST_VARIANT, CARRY_AUTO}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, LAST_VARIANT, CARRY_AUTO}; } else if (v.size() == 4) { - *this = {FFTShape{v[0], v[1], v[2]}, parseInt(v[3]), CARRY_AUTO}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, parseInt(v[3]), CARRY_AUTO}; } else if (v.size() == 5) { int c = parseInt(v[4]); assert(c == 0 || c == 1); - *this = {FFTShape{v[0], v[1], v[2]}, parseInt(v[3]), c == 0 ? CARRY_32 : CARRY_64}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, parseInt(v[3]), c == 0 ? CARRY_32 : CARRY_64}; } else { throw "FFT spec"; } @@ -208,6 +234,19 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : assert(variant_W(variant) < N_VARIANT_W); assert(variant_M(variant) < N_VARIANT_M); assert(variant_H(variant) < N_VARIANT_H); + + if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0; + else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1; + else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1; + else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1; + else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0; + else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0; + else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0; + else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0; + else throw "FFT type"; + + if ((FFT_FP64 && NTT_GF31) || (NTT_GF31 && NTT_GF61) || (FFT_FP32 && NTT_GF61)) WordSize = 8; + else WordSize = 4; } string FFTConfig::spec() const { @@ -215,8 +254,8 @@ string FFTConfig::spec() const { return carry == CARRY_AUTO ? s : (s + (carry == CARRY_32 ? ":0" : ":1")); } -double FFTConfig::maxBpw() const { - double b; +float FFTConfig::maxBpw() const { + float b; // Look up the pre-computed maximum bpw. The lookup table contains data for variants 000, 101, 202, 010, 111, 212. // For 4K width, the lookup table contains data for variants 100, 101, 202, 110, 111, 212 since BCAST only works for width <= 1024. if (variant_W(variant) == variant_H(variant) || @@ -225,8 +264,8 @@ double FFTConfig::maxBpw() const { } // Interpolate for the maximum bpw. This might could be improved upon. However, I doubt people will use these variants often. else { - double b1 = shape.bpw[variant_M(variant) * 3 + variant_W(variant)]; - double b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; + float b1 = shape.bpw[variant_M(variant) * 3 + variant_W(variant)]; + float b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; b = (b1 + b2) / 2.0; } return carry == CARRY_32 ? std::min(shape.carry32BPW(), b) : b; @@ -237,7 +276,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { if (!spec.empty()) { FFTConfig fft{spec}; if (fft.maxExp() * args.fftOverdrive < E) { - log("Warning: %s (max %u) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %lu) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } return fft; } @@ -263,7 +302,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { } -string numberK(u32 n) { +string numberK(u64 n) { u32 K = 1024; u32 M = K * K; diff --git a/src/FFTConfig.h b/src/FFTConfig.h index dc65cad0..b737d20b 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -17,10 +17,12 @@ class Args; // Format 'n' with a K or M suffix if multiple of 1024 or 1024*1024 -string numberK(u32 n); +string numberK(u64 n); using KeyVal = std::pair; +enum FFT_TYPES {FFT64=0, FFT3161=1, FFT3261=2, FFT61=3, FFT3231=50, FFT6431=51, FFT31=52, FFT32=53}; + class FFTShape { public: static constexpr const float MIN_BPW = 3; @@ -31,23 +33,25 @@ class FFTShape { static vector multiSpec(const string& spec); + enum FFT_TYPES fft_type; u32 width = 0; u32 middle = 0; u32 height = 0; - array bpw; + array bpw; FFTShape(u32 w = 1, u32 m = 1, u32 h = 1); - FFTShape(const string& w, const string& m, const string& h); + FFTShape(enum FFT_TYPES t, u32 w, u32 m, u32 h); + FFTShape(enum FFT_TYPES t, const string& w, const string& m, const string& h); explicit FFTShape(const string& spec); u32 size() const { return width * height * middle * 2; } u32 nW() const { return (width == 1024 || width == 256 /*|| width == 4096*/) ? 4 : 8; } u32 nH() const { return (height == 1024 || height == 256 /*|| height == 4096*/) ? 4 : 8; } - double maxBpw() const { return *max_element(bpw.begin(), bpw.end()); } - std::string spec() const { return numberK(width) + ':' + numberK(middle) + ':' + numberK(height); } + float maxBpw() const { return *max_element(bpw.begin(), bpw.end()); } + std::string spec() const { return (fft_type ? to_string(fft_type) + ':' : "") + numberK(width) + ':' + numberK(middle) + ':' + numberK(height); } - double carry32BPW() const; + float carry32BPW() const; bool needsLargeCarry(u32 E) const; bool isFavoredShape() const; }; @@ -66,12 +70,22 @@ inline u32 next_variant(u32 v) { new_v = (v / 100 + 1) * 100; return (new_v); } -enum CARRY_KIND { CARRY_32=0, CARRY_64=1, CARRY_AUTO=2}; +enum CARRY_KIND {CARRY_32=0, CARRY_64=1, CARRY_AUTO=2}; struct FFTConfig { public: static FFTConfig bestFit(const Args& args, u32 E, const std::string& spec); + // Which FP and NTT primes are involved in the FFT + bool FFT_FP64; + bool FFT_FP32; + bool NTT_GF31; + bool NTT_GF61; + // bool NTT_NCW; // Nick Craig-Wood prime not supported (yet?) + + // Size (in bytes) of integer data passed to FFTs/NTTs on the GPU + u32 WordSize; + FFTShape shape{}; u32 variant; u32 carry; @@ -80,8 +94,8 @@ struct FFTConfig { FFTConfig(FFTShape shape, u32 variant, u32 carry); std::string spec() const; - u32 size() const { return shape.size(); } - u32 maxExp() const { return maxBpw() * shape.size(); } + u64 size() const { return shape.size(); } + u64 maxExp() const { return maxBpw() * shape.size(); } - double maxBpw() const; + float maxBpw() const; }; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 16fb26b6..e2ce780b 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #define _USE_MATH_DEFINES #include @@ -39,8 +40,6 @@ namespace { -#if FFT_FP64 - u32 kAt(u32 H, u32 line, u32 col) { return (line + col * H) * 2; } double weight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { @@ -61,128 +60,106 @@ double invWeightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { double boundUnderOne(double x) { return std::min(x, nexttoward(1, 0)); } -Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { - u32 N = 2u * W * H; - - u32 groupWidth = W / nW; - - // Inverse + Forward - vector weightsConstIF; - vector weightsIF; - for (u32 thread = 0; thread < groupWidth; ++thread) { - auto iw = invWeight(N, E, H, 0, thread, 0); - auto w = weight(N, E, H, 0, thread, 0); - // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer - // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF - // buffer, so there is an unfortunate duplication of these weights. - if (!AmdGpu) { - weightsConstIF.push_back(2 * boundUnderOne(iw)); - weightsConstIF.push_back(2 * w); - } - weightsIF.push_back(2 * boundUnderOne(iw)); - weightsIF.push_back(2 * w); - } - - // the group order matches CarryA/M (not fftP/CarryFused). - for (u32 gy = 0; gy < H; ++gy) { - weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); - weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); - } - - vector bits; - - for (u32 line = 0; line < H; ++line) { - for (u32 thread = 0; thread < groupWidth; ) { - std::bitset<32> b; - for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { - for (u32 block = 0; block < nW; ++block) { - for (u32 rep = 0; rep < 2; ++rep) { - if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } - } - } - bits.push_back(b.to_ulong()); - } - } - assert(bits.size() == N / 32); - - return Weights{weightsConstIF, weightsIF, bits}; -} - -#endif - -#if FFT_FP32 - -u32 kAt(u32 H, u32 line, u32 col) { return (line + col * H) * 2; } - -float weight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { +float weight32(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N); } -float invWeight(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { +float invWeight32(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { return exp2(-(double)(extra(N, E, kAt(H, line, col) + rep)) / N); } -float weightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { +float weightM132(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; } -float invWeightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { +float invWeightM132(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { return exp2(- (double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; } float boundUnderOne(float x) { return std::min(x, nexttowardf(1, 0)); } -Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { +Weights genWeights(FFTConfig fft, u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { u32 N = 2u * W * H; - u32 groupWidth = W / nW; - // Inverse + Forward - vector weightsConstIF; - vector weightsIF; - for (u32 thread = 0; thread < groupWidth; ++thread) { - auto iw = invWeight(N, E, H, 0, thread, 0); - auto w = weight(N, E, H, 0, thread, 0); - // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer - // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF - // buffer, so there is an unfortunate duplication of these weights. - if (!AmdGpu) { - weightsConstIF.push_back(2 * boundUnderOne(iw)); - weightsConstIF.push_back(2 * w); + vector weightsConstIF; + vector weightsIF; + vector bits; + + if (fft.FFT_FP64) { + // Inverse + Forward + for (u32 thread = 0; thread < groupWidth; ++thread) { + auto iw = invWeight(N, E, H, 0, thread, 0); + auto w = weight(N, E, H, 0, thread, 0); + // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer + // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF + // buffer, so there is an unfortunate duplication of these weights. + if (!AmdGpu) { + weightsConstIF.push_back(2 * boundUnderOne(iw)); + weightsConstIF.push_back(2 * w); + } + weightsIF.push_back(2 * boundUnderOne(iw)); + weightsIF.push_back(2 * w); } - weightsIF.push_back(2 * boundUnderOne(iw)); - weightsIF.push_back(2 * w); - } - // the group order matches CarryA/M (not fftP/CarryFused). - for (u32 gy = 0; gy < H; ++gy) { - weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); - weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); + // the group order matches CarryA/M (not fftP/CarryFused). + for (u32 gy = 0; gy < H; ++gy) { + weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); + weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); + } } - - vector bits; - for (u32 line = 0; line < H; ++line) { - for (u32 thread = 0; thread < groupWidth; ) { - std::bitset<32> b; - for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { - for (u32 block = 0; block < nW; ++block) { - for (u32 rep = 0; rep < 2; ++rep) { - if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } - } + else if (fft.FFT_FP32) { + vector weightsConstIF32; + vector weightsIF32; + // Inverse + Forward + for (u32 thread = 0; thread < groupWidth; ++thread) { + auto iw = invWeight32(N, E, H, 0, thread, 0); + auto w = weight32(N, E, H, 0, thread, 0); + // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer + // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF + // buffer, so there is an unfortunate duplication of these weights. + if (!AmdGpu) { + weightsConstIF32.push_back(2 * boundUnderOne(iw)); + weightsConstIF32.push_back(2 * w); + } + weightsIF32.push_back(2 * boundUnderOne(iw)); + weightsIF32.push_back(2 * w); + } + + // the group order matches CarryA/M (not fftP/CarryFused). + for (u32 gy = 0; gy < H; ++gy) { + weightsIF32.push_back(invWeightM132(N, E, H, gy, 0, 0)); + weightsIF32.push_back(weightM132(N, E, H, gy, 0, 0)); + } + + // Copy the float vectors to the double vectors + weightsConstIF.resize(weightsConstIF32.size() / 2); + memcpy((double *) weightsConstIF.data(), weightsConstIF32.data(), weightsConstIF32.size() * sizeof(float)); + weightsIF.resize(weightsIF32.size() / 2); + memcpy((double *) weightsIF.data(), weightsIF32.data(), weightsIF32.size() * sizeof(float)); + } + + if (fft.FFT_FP64 || fft.FFT_FP64) { + for (u32 line = 0; line < H; ++line) { + for (u32 thread = 0; thread < groupWidth; ) { + std::bitset<32> b; + for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { + for (u32 block = 0; block < nW; ++block) { + for (u32 rep = 0; rep < 2; ++rep) { + if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } + } + } + } + bits.push_back(b.to_ulong()); } - bits.push_back(b.to_ulong()); } + assert(bits.size() == N / 32); } - assert(bits.size() == N / 32); return Weights{weightsConstIF, weightsIF, bits}; } -#endif - string toLiteral(i32 value) { return to_string(value); } string toLiteral(u32 value) { return to_string(value) + 'u'; } [[maybe_unused]] string toLiteral(i64 value) { return to_string(value) + "l"; } @@ -249,7 +226,7 @@ constexpr bool isInList(const string& s, initializer_list list) { } string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector& extraConf, u32 E, bool doLog, - bool &tail_single_wide, bool &tail_single_kernel, u32 &tail_trigs, u32 &pad_size) { + bool &tail_single_wide, bool &tail_single_kernel, u32 &pad_size) { map config; // Highest priority is the requested "extra" conf @@ -266,7 +243,6 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< // Default value for -use options that must also be parsed in C++ code tail_single_wide = 0, tail_single_kernel = 1; // Default tailSquare is double-wide in one kernel - tail_trigs = 2; // Default is calculating from scratch, no memory accesses pad_size = isAmdGpu(id) ? 256 : 0; // Default is 256 bytes for AMD, 0 for others // Validate -use options @@ -292,7 +268,13 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "MIDDLE_OUT_LDS_TRANSPOSE", "TAIL_KERNELS", "TAIL_TRIGS", - "TABMUL_CHAIN" + "TAIL_TRIGS31", + "TAIL_TRIGS32", + "TAIL_TRIGS61", + "TABMUL_CHAIN", + "TABMUL_CHAIN31", + "TABMUL_CHAIN32", + "TABMUL_CHAIN61" }); if (!isValid) { log("Warning: unrecognized -use key '%s'\n", k.c_str()); @@ -305,7 +287,6 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< if (atoi(v.c_str()) == 2) tail_single_wide = 0, tail_single_kernel = 1; if (atoi(v.c_str()) == 3) tail_single_wide = 0, tail_single_kernel = 0; } - if (k == "TAIL_TRIGS") tail_trigs = atoi(v.c_str()); if (k == "PAD") pad_size = atoi(v.c_str()); } @@ -332,64 +313,77 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< u32 N = fft.shape.size(); defines += toDefine("FFT_VARIANT", fft.variant); -#if FFT_FP64 | FFT_FP32 - defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); - defines += toDefine("IWEIGHT_STEP", invWeightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); - defines += toDefine("TAILT", root1Fancy(fft.shape.height * 2, 1)); - - TrigCoefs coefs = trigCoefs(fft.shape.size() / 4); - defines += toDefine("TRIG_SCALE", int(coefs.scale)); - defines += toDefine("TRIG_SIN", coefs.sinCoefs); - defines += toDefine("TRIG_COS", coefs.cosCoefs); -#endif -#if NTT_GF31 - defines += toDefine("TAILTGF31", root1GF31(fft.shape.height * 2, 1)); -#endif -#if NTT_GF61 - defines += toDefine("TAILTGF61", root1GF61(fft.shape.height * 2, 1)); -#endif - -// When using multiple NTT primes or hybrid FFT/NTT, each FFT/NTT prime's data buffer and trig values are combined into one buffer. -// The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". -#if FFT_FP64 & NTT_GF31 - // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. - defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); - defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); - defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); - defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); -#elif FFT_FP32 & NTT_GF31 - // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); - defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); - defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); - defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); -#elif FFT_FP32 & NTT_GF61 - // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); - defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); - defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); - defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); -#elif NTT_GF31 & NTT_GF61 - defines += toDefine("DISTGF31", 0); - defines += toDefine("DISTWTRIGGF31", 0); - defines += toDefine("DISTMTRIGGF31", 0); - defines += toDefine("DISTHTRIGGF31", 0); - // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. - defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); - defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); - defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); - defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); -#elif NTT_GF31 - defines += toDefine("DISTGF31", 0); - defines += toDefine("DISTWTRIGGF31", 0); - defines += toDefine("DISTMTRIGGF31", 0); - defines += toDefine("DISTHTRIGGF31", 0); -#elif NTT_GF61 - defines += toDefine("DISTGF61", 0); - defines += toDefine("DISTWTRIGGF61", 0); - defines += toDefine("DISTMTRIGGF61", 0); - defines += toDefine("DISTHTRIGGF61", 0); -#endif + if (fft.FFT_FP64 | fft.FFT_FP32) { + defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); + defines += toDefine("IWEIGHT_STEP", invWeightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); + if (fft.FFT_FP64) defines += toDefine("TAILT", root1Fancy(fft.shape.height * 2, 1)); + else defines += toDefine("TAILT", root1FancyFP32(fft.shape.height * 2, 1)); + + TrigCoefs coefs = trigCoefs(fft.shape.size() / 4); + defines += toDefine("TRIG_SCALE", int(coefs.scale)); + defines += toDefine("TRIG_SIN", coefs.sinCoefs); + defines += toDefine("TRIG_COS", coefs.cosCoefs); + } + if (fft.NTT_GF31) { + defines += toDefine("TAILTGF31", root1GF31(fft.shape.height * 2, 1)); + } + if (fft.NTT_GF61) { + defines += toDefine("TAILTGF61", root1GF61(fft.shape.height * 2, 1)); + } + + // Enable/disable code for each possible FP and NTT + defines += toDefine("FFT_FP64", (int) fft.FFT_FP64); + defines += toDefine("FFT_FP32", (int) fft.FFT_FP32); + defines += toDefine("NTT_GF31", (int) fft.NTT_GF31); + defines += toDefine("NTT_GF61", (int) fft.NTT_GF61); + defines += toDefine("WordSize", fft.WordSize); + + // When using multiple NTT primes or hybrid FFT/NTT, each FFT/NTT prime's data buffer and trig values are combined into one buffer. + // The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". + if (fft.FFT_FP64 && fft.NTT_GF31) { + // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. + defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.FFT_FP32 && fft.NTT_GF31) { + // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.FFT_FP32 && fft.NTT_GF61) { + // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.NTT_GF31 && fft.NTT_GF61) { + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); + // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. + defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.NTT_GF31) { + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); + } + else if (fft.NTT_GF61) { + defines += toDefine("DISTGF61", 0); + defines += toDefine("DISTWTRIGGF61", 0); + defines += toDefine("DISTMTRIGGF61", 0); + defines += toDefine("DISTHTRIGGF61", 0); + } // Calculate fractional bits-per-word = (E % N) / N * 2^64 u32 bpw_hi = (u64(E % N) << 32) / N; @@ -511,6 +505,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& args{*shared.args}, E(E), N(fft.shape.size()), + fft(fft), WIDTH(fft.shape.width), SMALL_H(fft.shape.height), BIG_H(SMALL_H * fft.shape.middle), @@ -518,12 +513,10 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& nW(fft.shape.nW()), nH(fft.shape.nH()), useLongCarry{args.carry == Args::CARRY_LONG}, - compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, - tail_single_wide, tail_single_kernel, tail_trigs, pad_size)}, + compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, tail_single_wide, tail_single_kernel, pad_size)}, #define K(name, ...) name(#name, &compiler, profile.make(#name), queue, __VA_ARGS__) -#if FFT_FP64 | FFT_FP32 K(kfftMidIn, "fftmiddlein.cl", "fftMiddleIn", hN / (BIG_H / SMALL_H)), K(kfftHin, "ffthin.cl", "fftHin", hN / nH), K(ktailSquareZero, "tailsquare.cl", "tailSquareZero", SMALL_H / nH * 2), @@ -536,9 +529,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& K(ktailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"), K(kfftMidOut, "fftmiddleout.cl", "fftMiddleOut", hN / (BIG_H / SMALL_H)), K(kfftW, "fftw.cl", "fftW", hN / nW), -#endif -#if NTT_GF31 K(kfftMidInGF31, "fftmiddlein.cl", "fftMiddleInGF31", hN / (BIG_H / SMALL_H)), K(kfftHinGF31, "ffthin.cl", "fftHinGF31", hN / nH), K(ktailSquareZeroGF31, "tailsquare.cl", "tailSquareZeroGF31", SMALL_H / nH * 2), @@ -551,9 +542,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& K(ktailMulLowGF31, "tailmul.cl", "tailMulGF31", hN / nH / 2, "-DMUL_LOW=1"), K(kfftMidOutGF31, "fftmiddleout.cl", "fftMiddleOutGF31", hN / (BIG_H / SMALL_H)), K(kfftWGF31, "fftw.cl", "fftWGF31", hN / nW), -#endif -#if NTT_GF61 K(kfftMidInGF61, "fftmiddlein.cl", "fftMiddleInGF61", hN / (BIG_H / SMALL_H)), K(kfftHinGF61, "ffthin.cl", "fftHinGF61", hN / nH), K(ktailSquareZeroGF61, "tailsquare.cl", "tailSquareZeroGF61", SMALL_H / nH * 2), @@ -566,7 +555,6 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& K(ktailMulLowGF61, "tailmul.cl", "tailMulGF61", hN / nH / 2, "-DMUL_LOW=1"), K(kfftMidOutGF61, "fftmiddleout.cl", "fftMiddleOutGF61", hN / (BIG_H / SMALL_H)), K(kfftWGF61, "fftw.cl", "fftWGF61", hN / nW), -#endif K(kfftP, "fftp.cl", "fftP", hN / nW), K(kCarryA, "carry.cl", "carry", hN / CARRY_LEN), @@ -592,33 +580,30 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& K(kernIsEqual, "etc.cl", "isEqual", 256 * 256, "-DISEQUAL=1"), K(sum64, "etc.cl", "sum64", 256 * 256, "-DSUM64=1"), -#if FFT_FP64 K(testTrig, "selftest.cl", "testTrig", 256 * 256), K(testFFT4, "selftest.cl", "testFFT4", 256), K(testFFT14, "selftest.cl", "testFFT14", 256), K(testFFT15, "selftest.cl", "testFFT15", 256), K(testFFT, "selftest.cl", "testFFT", 256), -#endif K(testTime, "selftest.cl", "testTime", 4096 * 64), #undef K - bufTrigH{shared.bufCache->smallTrigCombo(WIDTH, fft.shape.middle, SMALL_H, nH, fft.variant, tail_single_wide, tail_trigs)}, - bufTrigM{shared.bufCache->middleTrig(SMALL_H, BIG_H / SMALL_H, WIDTH)}, - bufTrigW{shared.bufCache->smallTrig(WIDTH, nW, fft.shape.middle, SMALL_H, nH, fft.variant, tail_single_wide, tail_trigs)}, + bufTrigH{shared.bufCache->smallTrigCombo(shared.args, fft, WIDTH, fft.shape.middle, SMALL_H, nH, tail_single_wide)}, + bufTrigM{shared.bufCache->middleTrig(shared.args, fft, SMALL_H, BIG_H / SMALL_H, WIDTH)}, + bufTrigW{shared.bufCache->smallTrig(shared.args, fft, WIDTH, nW, fft.shape.middle, SMALL_H, nH, tail_single_wide)}, -#if FFT_FP64 | FFT_FP32 - weights{genWeights(E, WIDTH, BIG_H, nW, isAmdGpu(q->context->deviceId()))}, + weights{genWeights(fft, E, WIDTH, BIG_H, nW, isAmdGpu(q->context->deviceId()))}, bufConstWeights{q->context, std::move(weights.weightsConstIF)}, bufWeights{q->context, std::move(weights.weightsIF)}, bufBits{q->context, std::move(weights.bitsCF)}, -#endif #define BUF(name, ...) name{profile.make(#name), queue, __VA_ARGS__} - BUF(bufData, N), - BUF(bufAux, N), - BUF(bufCheck, N), + // GPU Buffers containing integer data. Since this buffer is type i64, if fft.WordSize < 8 then we need less memory allocated. + BUF(bufData, N * fft.WordSize / sizeof(Word)), + BUF(bufAux, N * fft.WordSize / sizeof(Word)), + BUF(bufCheck, N * fft.WordSize / sizeof(Word)), // Every double-word (i.e. N/2) produces one carry. In addition we may have one extra group thus WIDTH more carries. BUF(bufCarry, N / 2 + WIDTH), BUF(bufReady, (N / 2 + WIDTH) / 32), // Every wavefront (32 or 64 lanes) needs to signal "carry is ready" @@ -629,9 +614,9 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& BUF(bufROE, ROE_SIZE), BUF(bufStatsCarry, CARRY_SIZE), - BUF(buf1, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), - BUF(buf2, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), - BUF(buf3, TOTAL_DATA_SIZE(WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf1, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf2, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf3, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), #undef BUF statsBits{u32(args.value("STATS", 0))}, @@ -645,7 +630,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& // Sometimes we do want to run a FFT beyond a reasonable BPW (e.g. during -ztune), and these situations // coincide with logFftSize == false if (fft.maxExp() < E) { - log("Warning: %s (max %u) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %lu) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } } @@ -656,64 +641,64 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& } #endif - useLongCarry = useLongCarry || (bitsPerWord < 12.0); + useLongCarry = useLongCarry || (bitsPerWord < 10.0); if (useLongCarry) { log("Using long carry!\n"); } -#if FFT_FP64 | FFT_FP32 - kfftMidIn.setFixedArgs(2, bufTrigM); - kfftHin.setFixedArgs(2, bufTrigH); - ktailSquareZero.setFixedArgs(2, bufTrigH); - ktailSquare.setFixedArgs(2, bufTrigH); - ktailMulLow.setFixedArgs(3, bufTrigH); - ktailMul.setFixedArgs(3, bufTrigH); - kfftMidOut.setFixedArgs(2, bufTrigM); - kfftW.setFixedArgs(2, bufTrigW); -#endif - -#if NTT_GF31 - kfftMidInGF31.setFixedArgs(2, bufTrigM); - kfftHinGF31.setFixedArgs(2, bufTrigH); - ktailSquareZeroGF31.setFixedArgs(2, bufTrigH); - ktailSquareGF31.setFixedArgs(2, bufTrigH); - ktailMulLowGF31.setFixedArgs(3, bufTrigH); - ktailMulGF31.setFixedArgs(3, bufTrigH); - kfftMidOutGF31.setFixedArgs(2, bufTrigM); - kfftWGF31.setFixedArgs(2, bufTrigW); -#endif - -#if NTT_GF61 - kfftMidInGF61.setFixedArgs(2, bufTrigM); - kfftHinGF61.setFixedArgs(2, bufTrigH); - ktailSquareZeroGF61.setFixedArgs(2, bufTrigH); - ktailSquareGF61.setFixedArgs(2, bufTrigH); - ktailMulLowGF61.setFixedArgs(3, bufTrigH); - ktailMulGF61.setFixedArgs(3, bufTrigH); - kfftMidOutGF61.setFixedArgs(2, bufTrigM); - kfftWGF61.setFixedArgs(2, bufTrigW); -#endif - -#if FFT_FP64 | FFT_FP32 // The FP versions take bufWeight arguments (and bufBits which may be deleted) - kfftP.setFixedArgs(2, bufTrigW, bufWeights); - for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry, bufWeights); } - for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(5, bufStatsCarry); } - for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(5, bufROE); } - for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { - k->setFixedArgs(3, bufCarry, bufReady, bufTrigW, bufBits, bufConstWeights, bufWeights); + if (fft.FFT_FP64 || fft.FFT_FP32) { + kfftMidIn.setFixedArgs(2, bufTrigM); + kfftHin.setFixedArgs(2, bufTrigH); + ktailSquareZero.setFixedArgs(2, bufTrigH); + ktailSquare.setFixedArgs(2, bufTrigH); + ktailMulLow.setFixedArgs(3, bufTrigH); + ktailMul.setFixedArgs(3, bufTrigH); + kfftMidOut.setFixedArgs(2, bufTrigM); + kfftW.setFixedArgs(2, bufTrigW); + } + + if (fft.NTT_GF31) { + kfftMidInGF31.setFixedArgs(2, bufTrigM); + kfftHinGF31.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF31.setFixedArgs(2, bufTrigH); + ktailSquareGF31.setFixedArgs(2, bufTrigH); + ktailMulLowGF31.setFixedArgs(3, bufTrigH); + ktailMulGF31.setFixedArgs(3, bufTrigH); + kfftMidOutGF31.setFixedArgs(2, bufTrigM); + kfftWGF31.setFixedArgs(2, bufTrigW); + } + + if (fft.NTT_GF61) { + kfftMidInGF61.setFixedArgs(2, bufTrigM); + kfftHinGF61.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF61.setFixedArgs(2, bufTrigH); + ktailSquareGF61.setFixedArgs(2, bufTrigH); + ktailMulLowGF61.setFixedArgs(3, bufTrigH); + ktailMulGF61.setFixedArgs(3, bufTrigH); + kfftMidOutGF61.setFixedArgs(2, bufTrigM); + kfftWGF61.setFixedArgs(2, bufTrigW); + } + + if (fft.FFT_FP64 || fft.FFT_FP32) { // The FP versions take bufWeight arguments (and bufBits which may be deleted) + kfftP.setFixedArgs(2, bufTrigW, bufWeights); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry, bufWeights); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(5, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(5, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { + k->setFixedArgs(3, bufCarry, bufReady, bufTrigW, bufBits, bufConstWeights, bufWeights); + } + for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(9, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(9, bufStatsCarry); } + } else { + kfftP.setFixedArgs(2, bufTrigW); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(4, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(4, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { + k->setFixedArgs(3, bufCarry, bufReady, bufTrigW); + } + for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(6, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(6, bufStatsCarry); } } - for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(9, bufROE); } - for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(9, bufStatsCarry); } -#else - kfftP.setFixedArgs(2, bufTrigW); - for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry); } - for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(4, bufStatsCarry); } - for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(4, bufROE); } - for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { - k->setFixedArgs(3, bufCarry, bufReady, bufTrigW); - } - for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(6, bufROE); } - for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(6, bufStatsCarry); } -#endif carryB.setFixedArgs(1, bufCarry); @@ -739,99 +724,51 @@ void Gpu::fftP(Buffer& out, Buffer& in) { } void Gpu::fftW(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - kfftW(out, in); -#endif -#if NTT_GF31 - kfftWGF31(out, in); -#endif -#if NTT_GF61 - kfftWGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) kfftW(out, in); + if (fft.NTT_GF31) kfftWGF31(out, in); + if (fft.NTT_GF61) kfftWGF61(out, in); } void Gpu::fftMidIn(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - kfftMidIn(out, in); -#endif -#if NTT_GF31 - kfftMidInGF31(out, in); -#endif -#if NTT_GF61 - kfftMidInGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) kfftMidIn(out, in); + if (fft.NTT_GF31) kfftMidInGF31(out, in); + if (fft.NTT_GF61) kfftMidInGF61(out, in); } void Gpu::fftMidOut(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - kfftMidOut(out, in); -#endif -#if NTT_GF31 - kfftMidOutGF31(out, in); -#endif -#if NTT_GF61 - kfftMidOutGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) kfftMidOut(out, in); + if (fft.NTT_GF31) kfftMidOutGF31(out, in); + if (fft.NTT_GF61) kfftMidOutGF61(out, in); } void Gpu::fftHin(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - kfftHin(out, in); -#endif -#if NTT_GF31 - kfftHinGF31(out, in); -#endif -#if NTT_GF61 - kfftHinGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) kfftHin(out, in); + if (fft.NTT_GF31) kfftHinGF31(out, in); + if (fft.NTT_GF61) kfftHinGF61(out, in); } void Gpu::tailSquareZero(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - ktailSquareZero(out, in); -#endif -#if NTT_GF31 - ktailSquareZeroGF31(out, in); -#endif -#if NTT_GF61 - ktailSquareZeroGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquareZero(out, in); + if (fft.NTT_GF31) ktailSquareZeroGF31(out, in); + if (fft.NTT_GF61) ktailSquareZeroGF61(out, in); } void Gpu::tailSquare(Buffer& out, Buffer& in) { -#if FFT_FP64 | FFT_FP32 - ktailSquare(out, in); -#endif -#if NTT_GF31 - ktailSquareGF31(out, in); -#endif -#if NTT_GF61 - ktailSquareGF61(out, in); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquare(out, in); + if (fft.NTT_GF31) ktailSquareGF31(out, in); + if (fft.NTT_GF61) ktailSquareGF61(out, in); } void Gpu::tailMul(Buffer& out, Buffer& in1, Buffer& in2) { -#if FFT_FP64 | FFT_FP32 - ktailMul(out, in1, in2); -#endif -#if NTT_GF31 - ktailMulGF31(out, in1, in2); -#endif -#if NTT_GF61 - ktailMulGF61(out, in1, in2); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) ktailMul(out, in1, in2); + if (fft.NTT_GF31) ktailMulGF31(out, in1, in2); + if (fft.NTT_GF61) ktailMulGF61(out, in1, in2); } void Gpu::tailMulLow(Buffer& out, Buffer& in1, Buffer& in2) { -#if FFT_FP64 | FFT_FP32 - ktailMulLow(out, in1, in2); -#endif -#if NTT_GF31 - ktailMulLowGF31(out, in1, in2); -#endif -#if NTT_GF61 - ktailMulLowGF61(out, in1, in2); -#endif + if (fft.FFT_FP64 || fft.FFT_FP32) ktailMulLow(out, in1, in2); + if (fft.NTT_GF31) ktailMulLowGF31(out, in1, in2); + if (fft.NTT_GF61) ktailMulLowGF61(out, in1, in2); } void Gpu::carryA(Buffer& out, Buffer& in) { @@ -962,13 +899,12 @@ vector Gpu::readChecked(Buffer& buf) { vector data = readOut(buf); u64 gpuSum = expectedVect[0]; - u64 hostSum = 0; + int even = 1; for (auto it = data.begin(), end = data.end(); it < end; ++it, even = !even) { - if (sizeof(Word) == 4) hostSum += even ? u64(u32(*it)) : (u64(*it) << 32); - if (sizeof(Word) == 8) hostSum += u64(*it); - if (sizeof(Word) == 16) hostSum += u64(*it) + u64((__int128) *it >> 64); + if (fft.WordSize == 4) hostSum += even ? u64(u32(*it)) : (u64(*it) << 32); + if (fft.WordSize == 8) hostSum += u64(*it); } if (hostSum == gpuSum) { @@ -1079,15 +1015,43 @@ void Gpu::logTimeKernels() { profile.reset(); } +vector Gpu::readWords(Buffer &buf) { + // GPU is returning either 4-byte or 8-byte integers. C++ code is expecting 8-byte integers. Handle the "no conversion" case. + if (fft.WordSize == 8) return buf.read(); + // Convert 32-bit GPU Words into 64-bit C++ Words + vector GPUdata = buf.read(); + vector CPUdata; + CPUdata.resize(GPUdata.size() * 2); + for (u32 i = 0; i < GPUdata.size(); ++i) { + CPUdata[2*i] = (i32) GPUdata[i]; + CPUdata[2*i+1] = (GPUdata[i] >> 32); + } + return CPUdata; +} + +void Gpu::writeWords(Buffer& buf, vector &words) { + // GPU is expecting either 4-byte or 8-byte integers. C++ code is using 8-byte integers. Handle the "no conversion" case. + if (fft.WordSize == 8) buf.write(std::move(words)); + // Convert 64-bit C++ Words into 32-bit GPU Words + else { + vector GPUdata; + GPUdata.resize(words.size() / 2); + for (u32 i = 0; i < words.size(); i += 2) { + GPUdata[i/2] = ((i64) words[i+1] << 32) | (i32) words[i]; + } + buf.write(std::move(GPUdata)); + } +} + vector Gpu::readOut(Buffer &buf) { transpOut(bufAux, buf); - return bufAux.read(); + return readWords(bufAux); } void Gpu::writeIn(Buffer& buf, const vector& words) { writeIn(buf, expandBits(words, N, E)); } void Gpu::writeIn(Buffer& buf, vector&& words) { - bufAux.write(std::move(words)); + writeWords(bufAux, words); transpIn(buf, bufAux); } @@ -1235,8 +1199,7 @@ bool Gpu::isEqual(Buffer& in1, Buffer& in2) { u64 Gpu::bufResidue(Buffer &buf) { readResidue(bufSmallOut, buf); - Word words[64]; - bufSmallOut.read(words, 64); + vector words = readWords(bufSmallOut); int carry = 0; for (int i = 0; i < 32; ++i) { diff --git a/src/Gpu.h b/src/Gpu.h index b5b424c8..20a4f713 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -77,20 +77,11 @@ class RoeInfo { double gumbelMiu{}, gumbelBeta{}; }; -#if FFT_FP64 struct Weights { vector weightsConstIF; vector weightsIF; vector bitsCF; }; -#endif -#if FFT_FP32 -struct Weights { - vector weightsConstIF; - vector weightsIF; - vector bitsCF; -}; -#endif class Gpu { Queue* queue; @@ -105,6 +96,7 @@ class Gpu { u32 E; u32 N; + FFTConfig fft; u32 WIDTH; u32 SMALL_H; u32 BIG_H; @@ -117,7 +109,7 @@ class Gpu { KernelCompiler compiler; -#if FFT_FP64 | FFT_FP32 + /* Kernels for FFT_FP64 or FFT_FP32 */ Kernel kfftMidIn; Kernel kfftHin; Kernel ktailSquareZero; @@ -126,9 +118,8 @@ class Gpu { Kernel ktailMulLow; Kernel kfftMidOut; Kernel kfftW; -#endif -#if NTT_GF31 + /* Kernels for NTT_GF31 */ Kernel kfftMidInGF31; Kernel kfftHinGF31; Kernel ktailSquareZeroGF31; @@ -137,9 +128,8 @@ class Gpu { Kernel ktailMulLowGF31; Kernel kfftMidOutGF31; Kernel kfftWGF31; -#endif -#if NTT_GF61 + /* Kernels for NTT_GF61 */ Kernel kfftMidInGF61; Kernel kfftHinGF61; Kernel ktailSquareZeroGF61; @@ -148,8 +138,8 @@ class Gpu { Kernel ktailMulLowGF61; Kernel kfftMidOutGF61; Kernel kfftWGF61; -#endif + /* Kernels dealing with the FP data and product of NTT primes */ Kernel kfftP; Kernel kCarryA; Kernel kCarryAROE; @@ -167,13 +157,13 @@ class Gpu { Kernel readResidue; Kernel kernIsEqual; Kernel sum64; -#if FFT_FP64 + + /* Weird test kernels */ Kernel testTrig; Kernel testFFT4; Kernel testFFT14; Kernel testFFT15; Kernel testFFT; -#endif Kernel testTime; // Kernel testKernel; @@ -181,7 +171,6 @@ class Gpu { // Copy of some -use options needed for Kernel, Trig, and Weights initialization bool tail_single_wide; // TailSquare processes one line at a time bool tail_single_kernel; // TailSquare does not use a separate kernel for line zero - u32 tail_trigs; // 0,1,2. Increasing values use more DP and less memory accesses u32 pad_size; // Pad size in bytes as specified on the command line or config.txt. Maximum value is 512. // Twiddles: trigonometry constant buffers, used in FFTs. @@ -191,19 +180,11 @@ class Gpu { TrigPtr bufTrigM; TrigPtr bufTrigW; - // The weights and the "bigWord bits" depend on the exponent. -#if FFT_FP64 + // Weights and the "bigWord bits" are only needed for FP64 and FP32 FFTs Weights weights; Buffer bufConstWeights; Buffer bufWeights; Buffer bufBits; // bigWord bits aligned for CarryFused/fftP -#endif -#if FFT_FP32 - Weights weights; - Buffer bufConstWeights; - Buffer bufWeights; - Buffer bufBits; // bigWord bits aligned for CarryFused/fftP -#endif // "integer word" buffers. These are "small buffers": N x int. Buffer bufData; // Main int buffer with the words. @@ -254,6 +235,9 @@ class Gpu { void carryFusedMul(Buffer& out, Buffer& in); void carryFusedLL(Buffer& out, Buffer& in); + vector readWords(Buffer &buf); + void writeWords(Buffer& buf, vector &words); + vector readOut(Buffer &buf); void writeIn(Buffer& buf, vector&& words); @@ -323,7 +307,7 @@ class Gpu { u64 dataResidue() { return bufResidue(bufData); } u64 checkResidue() { return bufResidue(bufCheck); } - + bool doCheck(u32 blockSize); void logTimeKernels(); @@ -343,7 +327,7 @@ class Gpu { // A:= A^h * B void expMul(Buffer& A, u64 h, Buffer& B); - + // return A^(2^n) Words expExp2(const Words& A, u32 n); vector> makeBufVector(u32 size); @@ -365,4 +349,5 @@ class Gpu { #define FP32_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(float) / sizeof(double) #define GF31_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(uint) / sizeof(double) #define GF61_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(ulong) / sizeof(double) -#define TOTAL_DATA_SIZE(W,M,H,pad) FFT_FP64 * FP64_DATA_SIZE(W,M,H,pad) + FFT_FP32 * FP32_DATA_SIZE(W,M,H,pad) + NTT_GF31 * GF31_DATA_SIZE(W,M,H,pad) + NTT_GF61 * GF61_DATA_SIZE(W,M,H,pad) +#define TOTAL_DATA_SIZE(fft,W,M,H,pad) fft.FFT_FP64 * FP64_DATA_SIZE(W,M,H,pad) + fft.FFT_FP32 * FP32_DATA_SIZE(W,M,H,pad) + \ + fft.NTT_GF31 * GF31_DATA_SIZE(W,M,H,pad) + fft.NTT_GF61 * GF61_DATA_SIZE(W,M,H,pad) diff --git a/src/Task.cpp b/src/Task.cpp index 7e148181..177128e3 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -83,23 +83,12 @@ OsInfo getOsInfo() { return getOsInfoMinimum(); } #endif -#if FFT_FP64 & NTT_GF31 -#define JSON_FFT_TYPE json("fft-type", "FP64+M31") -#elif FFT_FP64 -#define JSON_FFT_TYPE json("fft-type", "FP64") -#elif FFT_FP32 & NTT_GF31 -#define JSON_FFT_TYPE json("fft-type", "FP32+M31") -#elif FFT_FP32 & NTT_GF61 -#define JSON_FFT_TYPE json("fft-type", "FP32+M61") -#elif NTT_GF31 & NTT_GF61 -#define JSON_FFT_TYPE json("fft-type", "M31+M61") -#elif FFT_FP32 -#define JSON_FFT_TYPE json("fft-type", "FP32") -#elif NTT_GF31 -#define JSON_FFT_TYPE json("fft-type", "M31") -#elif NTT_GF61 -#define JSON_FFT_TYPE json("fft-type", "M61") -#endif +string ffttype(FFTConfig fft) { + return fft.shape.fft_type == FFT64 ? "FP64" : fft.shape.fft_type == FFT3161 ? "M31+M61" : + fft.shape.fft_type == FFT61 ? "M61" : fft.shape.fft_type == FFT3261 ? "FP32+M61" : + fft.shape.fft_type == FFT31 ? "M31" : fft.shape.fft_type == FFT3231 ? "FP32+M31" : + fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : "unknown"; +} string json(const vector& v) { bool isFirst = true; @@ -163,13 +152,13 @@ void writeResult(u32 instance, u32 E, const char *workType, const string &status } -void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res64, const string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const { +void Task::writeResultPRP(FFTConfig fft, const Args &args, u32 instance, bool isPrime, u64 res64, const string& res2048, u32 nErrors, const fs::path& proofPath) const { vector fields{json("res64", hex(res64)), json("res2048", res2048), json("residue-type", 1), json("errors", vector{json("gerbicz", nErrors)}), - JSON_FFT_TYPE, - json("fft-length", fftSize) + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()) }; // "proof":{"version":1, "power":6, "hashsize":64, "md5":"0123456789ABCDEF"}, @@ -188,10 +177,10 @@ void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res6 writeResult(instance, exponent, "PRP-3", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64, u32 fftSize) const { +void Task::writeResultLL(FFTConfig fft, const Args &args, u32 instance, bool isPrime, u64 res64) const { vector fields{json("res64", hex(res64)), - JSON_FFT_TYPE, - json("fft-length", fftSize), + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this }; @@ -199,14 +188,14 @@ void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64 writeResult(instance, exponent, "LL", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultCERT(const Args &args, u32 instance, array hash, u32 squarings, u32 fftSize) const { +void Task::writeResultCERT(FFTConfig fft, const Args &args, u32 instance, array hash, u32 squarings) const { string hexhash = hex(hash[3]) + hex(hash[2]) + hex(hash[1]) + hex(hash[0]); vector fields{json("worktype", "Cert"), json("exponent", exponent), json("sha3-hash", hexhash.c_str()), json("squarings", squarings), - JSON_FFT_TYPE, - json("fft-length", fftSize), + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this }; @@ -239,11 +228,11 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { if (kind == PRP) { auto [tmpIsPrime, res64, nErrors, proofPath, res2048] = gpu->isPrimePRP(*this); isPrime = tmpIsPrime; - writeResultPRP(*shared.args, instance, isPrime, res64, res2048, fft.size(), nErrors, proofPath); + writeResultPRP(fft, *shared.args, instance, isPrime, res64, res2048, nErrors, proofPath); } else { // LL auto [tmpIsPrime, res64] = gpu->isPrimeLL(*this); isPrime = tmpIsPrime; - writeResultLL(*shared.args, instance, isPrime, res64, fft.size()); + writeResultLL(fft, *shared.args, instance, isPrime, res64); } Worktodo::deleteTask(*this, instance); @@ -255,7 +244,7 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { } } else if (kind == CERT) { auto sha256 = gpu->isCERT(*this); - writeResultCERT(*shared.args, instance, sha256, squarings, fft.size()); + writeResultCERT(fft, *shared.args, instance, sha256, squarings); Worktodo::deleteTask(*this, instance); } else { throw "Unexpected task kind " + to_string(kind); diff --git a/src/Task.h b/src/Task.h index 6d870263..95f08024 100644 --- a/src/Task.h +++ b/src/Task.h @@ -5,6 +5,7 @@ #include "Args.h" #include "common.h" #include "GpuCommon.h" +#include "FFTConfig.h" #include @@ -27,7 +28,7 @@ class Task { string verifyPath; // For Verify void execute(GpuCommon shared, Queue* q, u32 instance); - void writeResultPRP(const Args&, u32 instance, bool isPrime, u64 res64, const std::string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const; - void writeResultLL(const Args&, u32 instance, bool isPrime, u64 res64, u32 fftSize) const; - void writeResultCERT(const Args&, u32 instance, array hash, u32 squarings, u32 fftSize) const; + void writeResultPRP(FFTConfig fft, const Args&, u32 instance, bool isPrime, u64 res64, const std::string& res2048, u32 nErrors, const fs::path& proofPath) const; + void writeResultLL(FFTConfig fft, const Args&, u32 instance, bool isPrime, u64 res64) const; + void writeResultCERT(FFTConfig fft, const Args&, u32 instance, array hash, u32 squarings) const; }; diff --git a/src/TrigBufCache.cpp b/src/TrigBufCache.cpp index 1e483b72..03c8e1e6 100644 --- a/src/TrigBufCache.cpp +++ b/src/TrigBufCache.cpp @@ -3,8 +3,6 @@ #include #include "TrigBufCache.h" -#if FFT_FP64 - #define SAVE_ONE_MORE_WIDTH_MUL 0 // I want to make saving the only option -- but rocm optimizer is inexplicably making it slower in carryfused #define SAVE_ONE_MORE_HEIGHT_MUL 1 // In tailSquare this is the fastest option @@ -223,11 +221,13 @@ vector genSmallTrigFP64(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigComboFP64(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { +vector genSmallTrigComboFP64(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { if (LOG_TRIG_ALLOC) { log("genSmallTrigComboFP64(%u, %u)\n", size, radix); } vector tab = genSmallTrigFP64(size, radix); + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating from scratch, no memory accesses + // From tailSquare pre-calculate some or all of these: T2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. u32 height = size; @@ -279,24 +279,13 @@ vector genMiddleTrigFP64(u32 smallH, u32 middle, u32 width) { return tab; } -#endif - /**************************************************************************/ /* Similar to above, but for an FFT based on floats */ /**************************************************************************/ -#if FFT_FP32 - -#define _USE_MATH_DEFINES -#include - -#ifndef M_PI -#define M_PI 3.1415926535897931 -#endif - // For small angles, return "fancy" cos - 1 for increased precision -float2 root1Fancy(u32 N, u32 k) { +float2 root1FancyFP32(u32 N, u32 k) { assert(!(N&7)); assert(k < N); assert(k < N/4); @@ -329,16 +318,16 @@ static float2 roundTrig(double lc, double ls) { } // Returns the primitive root of unity of order N, to the power k. -float2 root1(u32 N, u32 k) { +float2 root1FP32(u32 N, u32 k) { assert(k < N); if (k >= N/2) { - auto [c, s] = root1(N, k - N/2); + auto [c, s] = root1FP32(N, k - N/2); return {-c, -s}; } else if (k > N/4) { - auto [c, s] = root1(N, N/2 - k); + auto [c, s] = root1FP32(N, N/2 - k); return {-c, s}; } else if (k > N/8) { - auto [c, s] = root1(N, N/4 - k); + auto [c, s] = root1FP32(N, N/4 - k); return {s, c}; } else { assert(k <= N/8); @@ -355,7 +344,7 @@ vector genSmallTrigFP32(u32 size, u32 radix) { // old fft_WIDTH and fft_HEIGHT for (u32 line = 1; line < radix; ++line) { for (u32 col = 0; col < WG; ++col) { - tab.push_back(radix / line >= 8 ? root1Fancy(size, col * line) : root1(size, col * line)); + tab.push_back(radix / line >= 8 ? root1FancyFP32(size, col * line) : root1FP32(size, col * line)); } } tab.resize(size); @@ -363,20 +352,22 @@ vector genSmallTrigFP32(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigComboFP32(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { +vector genSmallTrigComboFP32(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { vector tab = genSmallTrigFP32(size, radix); + u32 tail_trigs = args->value("TAIL_TRIGS32", 2); // Default is calculating from scratch, no memory accesses + // From tailSquare pre-calculate some or all of these: F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. u32 height = size; // Output line 0 trig values to be read by every u,v pair of lines for (u32 me = 0; me < height / radix; ++me) { - tab.push_back(root1(width * middle * height, width * middle * me)); + tab.push_back(root1FP32(width * middle * height, width * middle * me)); } // Output the one or two F2 multipliers to be read by one u,v pair of lines for (u32 line = 0; line <= width * middle / 2; ++line) { - tab.push_back(root1Fancy(width * middle * height, line)); - if (!tail_single_wide) tab.push_back(root1Fancy(width * middle * height, line ? width * middle - line : width * middle / 2)); + tab.push_back(root1FancyFP32(width * middle * height, line)); + if (!tail_single_wide) tab.push_back(root1FancyFP32(width * middle * height, line ? width * middle - line : width * middle / 2)); } } if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with lousy FP performance? @@ -385,7 +376,7 @@ vector genSmallTrigComboFP32(u32 width, u32 middle, u32 size, u32 radix, for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); for (u32 me = 0; me < height / radix; ++me) { - tab.push_back(root1(width * middle * height, line + width * middle * me)); + tab.push_back(root1FP32(width * middle * height, line + width * middle * me)); } } } @@ -394,80 +385,72 @@ vector genSmallTrigComboFP32(u32 width, u32 middle, u32 size, u32 radix, return tab; } -// starting from a MIDDLE of 5 we consider angles in [0, 2Pi/MIDDLE] as worth storing with the -// cos-1 "fancy" trick. -#define SHARP_MIDDLE 5 - vector genMiddleTrigFP32(u32 smallH, u32 middle, u32 width) { vector tab; if (middle == 1) { tab.resize(1); } else { if (middle < SHARP_MIDDLE) { - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(smallH * middle, k)); } - for (u32 k = 0; k < width; ++k) { tab.push_back(root1(middle * width, k)); } - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1FP32(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(width * middle * smallH, k)); } } else { - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1Fancy(smallH * middle, k)); } - for (u32 k = 0; k < width; ++k) { tab.push_back(root1Fancy(middle * width, k)); } - for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FancyFP32(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1FancyFP32(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(width * middle * smallH, k)); } } } return tab; } -#endif - /**************************************************************************/ /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#if NTT_GF31 - // Z31 and GF31 code copied from Yves Gallot's mersenne2 program // Z/{2^31 - 1}Z: the prime field of order p = 2^31 - 1 class Z31 { private: - static const uint32_t _p = (uint32_t(1) << 31) - 1; - uint32_t _n; // 0 <= n < p - - static uint32_t _add(const uint32_t a, const uint32_t b) - { - const uint32_t t = a + b; - return t - ((t >= _p) ? _p : 0); - } - - static uint32_t _sub(const uint32_t a, const uint32_t b) - { - const uint32_t t = a - b; - return t + ((a < b) ? _p : 0); - } - - static uint32_t _mul(const uint32_t a, const uint32_t b) - { - const uint64_t t = a * uint64_t(b); - return _add(uint32_t(t) & _p, uint32_t(t >> 31)); - } + static const uint32_t _p = (uint32_t(1) << 31) - 1; + uint32_t _n; // 0 <= n < p + + static uint32_t _add(const uint32_t a, const uint32_t b) + { + const uint32_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint32_t _sub(const uint32_t a, const uint32_t b) + { + const uint32_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint32_t _mul(const uint32_t a, const uint32_t b) + { + const uint64_t t = a * uint64_t(b); + return _add(uint32_t(t) & _p, uint32_t(t >> 31)); + } public: - Z31() {} - explicit Z31(const uint32_t n) : _n(n) {} + Z31() {} + explicit Z31(const uint32_t n) : _n(n) {} - uint32_t get() const { return _n; } + uint32_t get() const { return _n; } - bool operator!=(const Z31 & rhs) const { return (_n != rhs._n); } + bool operator!=(const Z31 & rhs) const { return (_n != rhs._n); } - // Z31 neg() const { return Z31((_n == 0) ? 0 : _p - _n); } - // Z31 half() const { return Z31(((_n % 2 == 0) ? _n : (_n + _p)) / 2); } + // Z31 neg() const { return Z31((_n == 0) ? 0 : _p - _n); } + // Z31 half() const { return Z31(((_n % 2 == 0) ? _n : (_n + _p)) / 2); } - Z31 operator+(const Z31 & rhs) const { return Z31(_add(_n, rhs._n)); } - Z31 operator-(const Z31 & rhs) const { return Z31(_sub(_n, rhs._n)); } - Z31 operator*(const Z31 & rhs) const { return Z31(_mul(_n, rhs._n)); } + Z31 operator+(const Z31 & rhs) const { return Z31(_add(_n, rhs._n)); } + Z31 operator-(const Z31 & rhs) const { return Z31(_sub(_n, rhs._n)); } + Z31 operator*(const Z31 & rhs) const { return Z31(_mul(_n, rhs._n)); } - Z31 sqr() const { return Z31(_mul(_n, _n)); } + Z31 sqr() const { return Z31(_mul(_n, _n)); } }; @@ -475,35 +458,35 @@ class Z31 class GF31 { private: - Z31 _s0, _s1; - // a primitive root of order 2^32 which is a root of (0, 1). - static const uint64_t _h_order = uint64_t(1) << 32; - static const uint32_t _h_0 = 7735u, _h_1 = 748621u; + Z31 _s0, _s1; + // a primitive root of order 2^32 which is a root of (0, 1). + static const uint64_t _h_order = uint64_t(1) << 32; + static const uint32_t _h_0 = 7735u, _h_1 = 748621u; public: - GF31() {} - explicit GF31(const Z31 & s0, const Z31 & s1) : _s0(s0), _s1(s1) {} - explicit GF31(const uint32_t n0, const uint32_t n1) : _s0(n0), _s1(n1) {} + GF31() {} + explicit GF31(const Z31 & s0, const Z31 & s1) : _s0(s0), _s1(s1) {} + explicit GF31(const uint32_t n0, const uint32_t n1) : _s0(n0), _s1(n1) {} - const Z31 & s0() const { return _s0; } - const Z31 & s1() const { return _s1; } + const Z31 & s0() const { return _s0; } + const Z31 & s1() const { return _s1; } - GF31 operator+(const GF31 & rhs) const { return GF31(_s0 + rhs._s0, _s1 + rhs._s1); } - GF31 operator-(const GF31 & rhs) const { return GF31(_s0 - rhs._s0, _s1 - rhs._s1); } + GF31 operator+(const GF31 & rhs) const { return GF31(_s0 + rhs._s0, _s1 + rhs._s1); } + GF31 operator-(const GF31 & rhs) const { return GF31(_s0 - rhs._s0, _s1 - rhs._s1); } - GF31 sqr() const { const Z31 t = _s0 * _s1; return GF31(_s0.sqr() - _s1.sqr(), t + t); } - GF31 mul(const GF31 & rhs) const { return GF31(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + GF31 sqr() const { const Z31 t = _s0 * _s1; return GF31(_s0.sqr() - _s1.sqr(), t + t); } + GF31 mul(const GF31 & rhs) const { return GF31(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } - GF31 pow(const uint64_t e) const - { - if (e == 0) return GF31(1u, 0u); - GF31 r = GF31(1u, 0u), y = *this; - for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } - return r.mul(y); - } + GF31 pow(const uint64_t e) const + { + if (e == 0) return GF31(1u, 0u); + GF31 r = GF31(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } - static const GF31 root_one(const size_t n) { return GF31(Z31(_h_0), Z31(_h_1)).pow(_h_order / n); } - static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 30) / n) % 31); } + static const GF31 root_one(const size_t n) { return GF31(Z31(_h_0), Z31(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 30) / n) % 31); } }; // Returns the primitive root of unity of order N, to the power k. @@ -521,7 +504,6 @@ vector genSmallTrigGF31(u32 size, u32 radix) { u32 WG = size / radix; vector tab; -// old fft_WIDTH and fft_HEIGHT GF31 root1size = GF31::root_one(size); for (u32 line = 1; line < radix; ++line) { for (u32 col = 0; col < WG; ++col) { @@ -533,9 +515,11 @@ vector genSmallTrigGF31(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigComboGF31(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { +vector genSmallTrigComboGF31(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { vector tab = genSmallTrigGF31(size, radix); + u32 tail_trigs = args->value("TAIL_TRIGS31", 0); // Default is reading all trigs from memory + // From tailSquareGF31 pre-calculate some or all of these: GF31 trig = slowTrigGF31(line + H * lowMe, ND / NH * 2); u32 height = size; GF31 root1wmh = GF31::root_one(width * middle * height); @@ -581,94 +565,90 @@ vector genMiddleTrigGF31(u32 smallH, u32 middle, u32 width) { return tab; } -#endif - /**************************************************************************/ /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#if NTT_GF61 - // Z61 and GF61 code copied from Yves Gallot's mersenne2 program // Z/{2^61 - 1}Z: the prime field of order p = 2^61 - 1 class Z61 { private: - static const uint64_t _p = (uint64_t(1) << 61) - 1; - uint64_t _n; // 0 <= n < p - - static uint64_t _add(const uint64_t a, const uint64_t b) - { - const uint64_t t = a + b; - return t - ((t >= _p) ? _p : 0); - } - - static uint64_t _sub(const uint64_t a, const uint64_t b) - { - const uint64_t t = a - b; - return t + ((a < b) ? _p : 0); - } - - static uint64_t _mul(const uint64_t a, const uint64_t b) - { - const __uint128_t t = a * __uint128_t(b); - const uint64_t lo = uint64_t(t), hi = uint64_t(t >> 64); - const uint64_t lo61 = lo & _p, hi61 = (lo >> 61) | (hi << 3); - return _add(lo61, hi61); - } + static const uint64_t _p = (uint64_t(1) << 61) - 1; + uint64_t _n; // 0 <= n < p + + static uint64_t _add(const uint64_t a, const uint64_t b) + { + const uint64_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint64_t _sub(const uint64_t a, const uint64_t b) + { + const uint64_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint64_t _mul(const uint64_t a, const uint64_t b) + { + const __uint128_t t = a * __uint128_t(b); + const uint64_t lo = uint64_t(t), hi = uint64_t(t >> 64); + const uint64_t lo61 = lo & _p, hi61 = (lo >> 61) | (hi << 3); + return _add(lo61, hi61); + } public: - Z61() {} - explicit Z61(const uint64_t n) : _n(n) {} + Z61() {} + explicit Z61(const uint64_t n) : _n(n) {} - uint64_t get() const { return _n; } + uint64_t get() const { return _n; } - bool operator!=(const Z61 & rhs) const { return (_n != rhs._n); } + bool operator!=(const Z61 & rhs) const { return (_n != rhs._n); } - Z61 operator+(const Z61 & rhs) const { return Z61(_add(_n, rhs._n)); } - Z61 operator-(const Z61 & rhs) const { return Z61(_sub(_n, rhs._n)); } - Z61 operator*(const Z61 & rhs) const { return Z61(_mul(_n, rhs._n)); } + Z61 operator+(const Z61 & rhs) const { return Z61(_add(_n, rhs._n)); } + Z61 operator-(const Z61 & rhs) const { return Z61(_sub(_n, rhs._n)); } + Z61 operator*(const Z61 & rhs) const { return Z61(_mul(_n, rhs._n)); } - Z61 sqr() const { return Z61(_mul(_n, _n)); } + Z61 sqr() const { return Z61(_mul(_n, _n)); } }; // GF((2^61 - 1)^2): the prime field of order p^2, p = 2^61 - 1 class GF61 { private: - Z61 _s0, _s1; - // Primitive root of order 2^62 which is a root of (0, 1). This root corresponds to 2*pi*i*j/N in FFTs. PRPLL FFTs use this root. Thanks, Yves! - static const uint64_t _h_0 = 264036120304204ull, _h_1 = 4677669021635377ull; - // Primitive root of order 2^62 which is a root of (0, -1). This root corresponds to -2*pi*i*j/N in FFTs. - //static const uint64_t _h_0 = 481139922016222ull, _h_1 = 814659809902011ull; - static const uint64_t _h_order = uint64_t(1) << 62; + Z61 _s0, _s1; + // Primitive root of order 2^62 which is a root of (0, 1). This root corresponds to 2*pi*i*j/N in FFTs. PRPLL FFTs use this root. Thanks, Yves! + static const uint64_t _h_0 = 264036120304204ull, _h_1 = 4677669021635377ull; + // Primitive root of order 2^62 which is a root of (0, -1). This root corresponds to -2*pi*i*j/N in FFTs. + //static const uint64_t _h_0 = 481139922016222ull, _h_1 = 814659809902011ull; + static const uint64_t _h_order = uint64_t(1) << 62; public: - GF61() {} - explicit GF61(const Z61 & s0, const Z61 & s1) : _s0(s0), _s1(s1) {} - explicit GF61(const uint64_t n0, const uint64_t n1) : _s0(n0), _s1(n1) {} + GF61() {} + explicit GF61(const Z61 & s0, const Z61 & s1) : _s0(s0), _s1(s1) {} + explicit GF61(const uint64_t n0, const uint64_t n1) : _s0(n0), _s1(n1) {} - const Z61 & s0() const { return _s0; } - const Z61 & s1() const { return _s1; } + const Z61 & s0() const { return _s0; } + const Z61 & s1() const { return _s1; } - GF61 operator+(const GF61 & rhs) const { return GF61(_s0 + rhs._s0, _s1 + rhs._s1); } - GF61 operator-(const GF61 & rhs) const { return GF61(_s0 - rhs._s0, _s1 - rhs._s1); } + GF61 operator+(const GF61 & rhs) const { return GF61(_s0 + rhs._s0, _s1 + rhs._s1); } + GF61 operator-(const GF61 & rhs) const { return GF61(_s0 - rhs._s0, _s1 - rhs._s1); } - GF61 sqr() const { const Z61 t = _s0 * _s1; return GF61(_s0.sqr() - _s1.sqr(), t + t); } - GF61 mul(const GF61 & rhs) const { return GF61(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + GF61 sqr() const { const Z61 t = _s0 * _s1; return GF61(_s0.sqr() - _s1.sqr(), t + t); } + GF61 mul(const GF61 & rhs) const { return GF61(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } - GF61 pow(const uint64_t e) const - { - if (e == 0) return GF61(1u, 0u); - GF61 r = GF61(1u, 0u), y = *this; - for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } - return r.mul(y); - } + GF61 pow(const uint64_t e) const + { + if (e == 0) return GF61(1u, 0u); + GF61 r = GF61(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } - static const GF61 root_one(const size_t n) { return GF61(Z61(_h_0), Z61(_h_1)).pow(_h_order / n); } - static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 60) / n) % 61); } + static const GF61 root_one(const size_t n) { return GF61(Z61(_h_0), Z61(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 60) / n) % 61); } }; // Returns the primitive root of unity of order N, to the power k. @@ -686,7 +666,6 @@ vector genSmallTrigGF61(u32 size, u32 radix) { u32 WG = size / radix; vector tab; -// old fft_WIDTH and fft_HEIGHT GF61 root1size = GF61::root_one(size); for (u32 line = 1; line < radix; ++line) { for (u32 col = 0; col < WG; ++col) { @@ -698,9 +677,11 @@ vector genSmallTrigGF61(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigComboGF61(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { +vector genSmallTrigComboGF61(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { vector tab = genSmallTrigGF61(size, radix); + u32 tail_trigs = args->value("TAIL_TRIGS61", 0); // Default is reading all trigs from memory + // From tailSquareGF61 pre-calculate some or all of these: GF61 trig = slowTrigGF61(line + H * lowMe, ND / NH * 2); u32 height = size; GF61 root1wmh = GF61::root_one(width * middle * height); @@ -746,126 +727,124 @@ vector genMiddleTrigGF61(u32 smallH, u32 middle, u32 width) { return tab; } -#endif - /**********************************************************/ /* Build all the needed trig values into one big buffer */ /**********************************************************/ -vector genSmallTrig(u32 size, u32 radix) { +vector genSmallTrig(FFTConfig fft, u32 size, u32 radix) { vector tab; u32 tabsize; -#if FFT_FP64 - tab = genSmallTrigFP64(size, radix); - tab.resize(SMALLTRIG_FP64_SIZE(size, 0, 0, 0)); -#endif + if (fft.FFT_FP64) { + tab = genSmallTrigFP64(size, radix); + tab.resize(SMALLTRIG_FP64_SIZE(size, 0, 0, 0)); + } -#if FFT_FP32 - vector tab1 = genSmallTrigFP32(size, radix); - tab1.resize(SMALLTRIG_FP32_SIZE(size, 0, 0, 0)); - // Append tab1 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab1.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); -#endif + if (fft.FFT_FP32) { + vector tab1 = genSmallTrigFP32(size, radix); + tab1.resize(SMALLTRIG_FP32_SIZE(size, 0, 0, 0)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } -#if NTT_GF31 - vector tab2 = genSmallTrigGF31(size, radix); - tab2.resize(SMALLTRIG_GF31_SIZE(size, 0, 0, 0)); - // Append tab2 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab2.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); -#endif + if (fft.NTT_GF31) { + vector tab2 = genSmallTrigGF31(size, radix); + tab2.resize(SMALLTRIG_GF31_SIZE(size, 0, 0, 0)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } -#if NTT_GF61 - vector tab3 = genSmallTrigGF61(size, radix); - tab3.resize(SMALLTRIG_GF61_SIZE(size, 0, 0, 0)); - // Append tab3 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab3.size()); - memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); -#endif + if (fft.NTT_GF61) { + vector tab3 = genSmallTrigGF61(size, radix); + tab3.resize(SMALLTRIG_GF61_SIZE(size, 0, 0, 0)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } return tab; } -vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { +vector genSmallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { vector tab; u32 tabsize; -#if FFT_FP64 - tab = genSmallTrigComboFP64(width, middle, size, radix, tail_single_wide, tail_trigs); - tab.resize(SMALLTRIGCOMBO_FP64_SIZE(width, middle, size, radix)); -#endif + if (fft.FFT_FP64) { + tab = genSmallTrigComboFP64(args, width, middle, size, radix, tail_single_wide); + tab.resize(SMALLTRIGCOMBO_FP64_SIZE(width, middle, size, radix)); + } -#if FFT_FP32 - vector tab1 = genSmallTrigComboFP32(width, middle, size, radix, tail_single_wide, tail_trigs); - tab1.resize(SMALLTRIGCOMBO_FP32_SIZE(width, middle, size, radix)); - // Append tab1 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab1.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); -#endif + if (fft.FFT_FP32) { + vector tab1 = genSmallTrigComboFP32(args, width, middle, size, radix, tail_single_wide); + tab1.resize(SMALLTRIGCOMBO_FP32_SIZE(width, middle, size, radix)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } -#if NTT_GF31 - vector tab2 = genSmallTrigComboGF31(width, middle, size, radix, tail_single_wide, tail_trigs); - tab2.resize(SMALLTRIGCOMBO_GF31_SIZE(width, middle, size, radix)); - // Append tab2 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab2.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); -#endif + if (fft.NTT_GF31) { + vector tab2 = genSmallTrigComboGF31(args, width, middle, size, radix, tail_single_wide); + tab2.resize(SMALLTRIGCOMBO_GF31_SIZE(width, middle, size, radix)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } -#if NTT_GF61 - vector tab3 = genSmallTrigComboGF61(width, middle, size, radix, tail_single_wide, tail_trigs); - tab3.resize(SMALLTRIGCOMBO_GF61_SIZE(width, middle, size, radix)); - // Append tab3 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab3.size()); - memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); -#endif + if (fft.NTT_GF61) { + vector tab3 = genSmallTrigComboGF61(args, width, middle, size, radix, tail_single_wide); + tab3.resize(SMALLTRIGCOMBO_GF61_SIZE(width, middle, size, radix)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } return tab; } -vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { +vector genMiddleTrig(FFTConfig fft, u32 smallH, u32 middle, u32 width) { vector tab; u32 tabsize; -#if FFT_FP64 - tab = genMiddleTrigFP64(smallH, middle, width); - tab.resize(MIDDLETRIG_FP64_SIZE(width, middle, smallH)); -#endif + if (fft.FFT_FP64) { + tab = genMiddleTrigFP64(smallH, middle, width); + tab.resize(MIDDLETRIG_FP64_SIZE(width, middle, smallH)); + } -#if FFT_FP32 - vector tab1 = genMiddleTrigFP32(smallH, middle, width); - tab1.resize(MIDDLETRIG_FP32_SIZE(width, middle, smallH)); - // Append tab1 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab1.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); -#endif + if (fft.FFT_FP32) { + vector tab1 = genMiddleTrigFP32(smallH, middle, width); + tab1.resize(MIDDLETRIG_FP32_SIZE(width, middle, smallH)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } -#if NTT_GF31 - vector tab2 = genMiddleTrigGF31(smallH, middle, width); - tab2.resize(MIDDLETRIG_GF31_SIZE(width, middle, smallH)); - // Append tab2 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab2.size() / 2); - memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); -#endif + if (fft.NTT_GF31) { + vector tab2 = genMiddleTrigGF31(smallH, middle, width); + tab2.resize(MIDDLETRIG_GF31_SIZE(width, middle, smallH)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } -#if NTT_GF61 - vector tab3 = genMiddleTrigGF61(smallH, middle, width); - tab3.resize(MIDDLETRIG_GF61_SIZE(width, middle, smallH)); - // Append tab3 to tab - tabsize = tab.size(); - tab.resize(tabsize + tab3.size()); - memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); -#endif + if (fft.NTT_GF61) { + vector tab3 = genMiddleTrigGF61(smallH, middle, width); + tab3.resize(MIDDLETRIG_GF61_SIZE(width, middle, smallH)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } return tab; } @@ -875,61 +854,77 @@ vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { /* Code to manage a cache of trigBuffers */ /********************************************************/ +#define make_key_part(b,tt,b31,tt31,b32,tt32,b61,tt61,tk) ((((((((b+tt) << 2) + b31+tt31) << 2) + b32+tt32) << 2) + b61+tt61) << 2) + tk + TrigBufCache::~TrigBufCache() = default; -TrigPtr TrigBufCache::smallTrig(u32 width, u32 nW, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs) { +TrigPtr TrigBufCache::smallTrig(Args *args, FFTConfig fft, u32 width, u32 nW, u32 middle, u32 height, u32 nH, bool tail_single_wide) { lock_guard lock{mut}; auto& m = small; TrigPtr p{}; - // See if there is an existing smallTrigCombo that we can return (using only as subset of the data) + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating FP64 trigs from scratch, no memory accesses + u32 tail_trigs31 = args->value("TAIL_TRIGS31", 2); // Default is reading GF31 trigs from memory + u32 tail_trigs32 = args->value("TAIL_TRIGS32", 2); // Default is calculating FP32 trigs from scratch, no memory accesses + u32 tail_trigs61 = args->value("TAIL_TRIGS61", 2); // Default is reading GF61 trigs from memory + u32 key_part = make_key_part(fft.FFT_FP64, tail_trigs, fft.NTT_GF31, tail_trigs31, fft.FFT_FP32, tail_trigs32, fft.NTT_GF61, tail_trigs61, tail_single_wide); + + // See if there is an existing smallTrigCombo that we can return (using only a subset of the data) // In theory, we could match any smallTrigCombo where width matches. However, SMALLTRIG_GF31_SIZE wouldn't be able to figure out the size. // In practice, those cases will likely never arise. if (width == height && nW == nH) { - decay_t::key_type key{height, nH, width, middle, tail_single_wide, tail_trigs}; + decay_t::key_type key{height, nH, width, middle, key_part}; auto it = m.find(key); if (it != m.end() && (p = it->second.lock())) return p; } // See if there is an existing non-combo smallTrig that we can return - decay_t::key_type key{width, nW, 0, 0, 0, 0}; + decay_t::key_type key{width, nW, 0, 0, key_part}; auto it = m.find(key); if (it != m.end() && (p = it->second.lock())) return p; // Create a new non-combo - p = make_shared(context, genSmallTrig(width, nW)); + p = make_shared(context, genSmallTrig(fft, width, nW)); m[key] = p; smallCache.add(p); return p; } -TrigPtr TrigBufCache::smallTrigCombo(u32 width, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs) { - if (tail_trigs == 2 && !NTT_GF31 && !NTT_GF61) // No pre-computed trig values. We might be able to share this trig table with fft_WIDTH - return smallTrig(height, nH, middle, height, nH, variant, tail_single_wide, tail_trigs); +TrigPtr TrigBufCache::smallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 height, u32 nH, bool tail_single_wide) { + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating FP64 trigs from scratch, no memory accesses + u32 tail_trigs31 = args->value("TAIL_TRIGS31", 2); // Default is reading GF31 trigs from memory + u32 tail_trigs32 = args->value("TAIL_TRIGS32", 2); // Default is calculating FP32 trigs from scratch, no memory accesses + u32 tail_trigs61 = args->value("TAIL_TRIGS61", 2); // Default is reading GF61 trigs from memory + u32 key_part = make_key_part(fft.FFT_FP64, tail_trigs, fft.NTT_GF31, tail_trigs31, fft.FFT_FP32, tail_trigs32, fft.NTT_GF61, tail_trigs61, tail_single_wide); + + // If there are no pre-computed trig values we might be able to share this trig table with fft_WIDTH + if (((tail_trigs == 2 && fft.FFT_FP64) || (tail_trigs32 == 2 && fft.FFT_FP32)) && !fft.NTT_GF31 && !fft.NTT_GF61) + return smallTrig(args, fft, height, nH, middle, height, nH, tail_single_wide); lock_guard lock{mut}; auto& m = small; - decay_t::key_type key{height, nH, width, middle, tail_single_wide, tail_trigs}; + decay_t::key_type key{height, nH, width, middle, key_part}; TrigPtr p{}; auto it = m.find(key); if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genSmallTrigCombo(width, middle, height, nH, tail_single_wide, tail_trigs)); + p = make_shared(context, genSmallTrigCombo(args, fft, width, middle, height, nH, tail_single_wide)); m[key] = p; smallCache.add(p); } return p; } -TrigPtr TrigBufCache::middleTrig(u32 SMALL_H, u32 MIDDLE, u32 width) { +TrigPtr TrigBufCache::middleTrig(Args *args, FFTConfig fft, u32 SMALL_H, u32 MIDDLE, u32 width) { lock_guard lock{mut}; auto& m = middle; - decay_t::key_type key{SMALL_H, MIDDLE, width}; + u32 key_part = make_key_part(fft.FFT_FP64, 0, fft.NTT_GF31, 0, fft.FFT_FP32, 0, fft.NTT_GF61, 0, 0); + decay_t::key_type key{SMALL_H, MIDDLE, width, key_part}; TrigPtr p{}; auto it = m.find(key); if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genMiddleTrig(SMALL_H, MIDDLE, width)); + p = make_shared(context, genMiddleTrig(fft, SMALL_H, MIDDLE, width)); m[key] = p; middleCache.add(p); } diff --git a/src/TrigBufCache.h b/src/TrigBufCache.h index c1a6427c..d5e5317a 100644 --- a/src/TrigBufCache.h +++ b/src/TrigBufCache.h @@ -3,6 +3,7 @@ #pragma once #include "Buffer.h" +#include "FFTConfig.h" #include @@ -26,8 +27,8 @@ class TrigBufCache { const Context* context; std::mutex mut; - std::map, TrigPtr::weak_type> small; - std::map, TrigPtr::weak_type> middle; + std::map, TrigPtr::weak_type> small; + std::map, TrigPtr::weak_type> middle; // The shared-pointers below keep the most recent set of buffers alive even without any Gpu instance // referencing them. This allows a single worker to delete & re-create the Gpu instance and still reuse the buffers. @@ -41,29 +42,20 @@ class TrigBufCache { ~TrigBufCache(); - TrigPtr smallTrigCombo(u32 width, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs); - TrigPtr middleTrig(u32 SMALL_H, u32 MIDDLE, u32 W); - TrigPtr smallTrig(u32 width, u32 nW, u32 middle, u32 height, u32 nH, u32 variant, bool tail_single_wide, u32 tail_trigs); + TrigPtr smallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 height, u32 nH, bool tail_single_wide); + TrigPtr middleTrig(Args *args, FFTConfig fft, u32 SMALL_H, u32 MIDDLE, u32 W); + TrigPtr smallTrig(Args *args, FFTConfig fft, u32 width, u32 nW, u32 middle, u32 height, u32 nH, bool tail_single_wide); }; -#if FFT_FP64 double2 root1Fancy(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision double2 root1(u32 N, u32 k); -#endif -#if FFT_FP32 -float2 root1Fancy(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision -float2 root1(u32 N, u32 k); -#endif +float2 root1FancyFP32(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision +float2 root1FP32(u32 N, u32 k); -#if NTT_GF31 uint2 root1GF31(u32 N, u32 k); -#endif - -#if NTT_GF61 ulong2 root1GF61(u32 N, u32 k); -#endif // Compute the size of the largest possible trig buffer given width, middle, height (in number of float2 values) #define SMALLTRIG_FP64_SIZE(W,M,H,nH) (W != H || H == 0 ? W * 5 : SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH)) // See genSmallTrigFP64 diff --git a/src/cl/base.cl b/src/cl/base.cl index 0872ef42..24986309 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -103,9 +103,22 @@ G_H "group height" == SMALL_HEIGHT / NH #error FFT_VARIANT_H must be between 0 and 2 #endif +#if !defined(BIGLIT) +#define BIGLIT 1 +#endif + #if !defined(TABMUL_CHAIN) #define TABMUL_CHAIN 0 #endif +#if !defined(TABMUL_CHAIN31) +#define TABMUL_CHAIN31 0 +#endif +#if !defined(TABMUL_CHAIN32) +#define TABMUL_CHAIN32 0 +#endif +#if !defined(TABMUL_CHAIN61) +#define TABMUL_CHAIN61 0 +#endif #if !defined(MIDDLE_CHAIN) #define MIDDLE_CHAIN 0 @@ -156,26 +169,6 @@ typedef ulong u64; typedef __int128 i128; typedef unsigned __int128 u128; -// Typedefs and defines for supporting hybrid FFTs -#if !defined(FFT_FP64) -#define FFT_FP64 1 -#endif -#if !defined(FFT_FP32) -#define FFT_FP32 0 -#endif -#if !defined(NTT_GF31) -#define NTT_GF31 0 -#endif -#if !defined(NTT_GF61) -#define NTT_GF61 0 -#endif -#if !defined(NTT_NCW) -#define NTT_NCW 0 -#endif -#if NTT_NCW -#error Nick Craig-Woods NTT prime is not supported now -#endif - // Data types for data stored in FFTs and NTTs during the transform typedef double T; // For historical reasons, classic FFTs using doubles call their data T and T2. typedef double2 T2; // A complex value using doubles in a classic FFT. @@ -189,18 +182,22 @@ typedef ulong2 GF61; // A complex value using two Z61s. For a GF(M61^2) //typedef ulong2 NCW2; // A complex value using NCWs. For a Nick Craig-Wood's insipred NTT using prime 2^64 - 2^32 + 1. // Typedefs for "combo" FFT/NTTs (multiple NTT primes or hybrid FFT/NTT). +#define COMBO_FFT (FFT_FP64 + FFT_FP32 + NTT_GF31 + NTT_GF61 > 1) +// Sanity check for supported FFT/NTT +#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31) +#elif !COMBO_FFT | (FFT_FP32 & NTT_GF31 & !FFT_FP64 & !NTT_GF61) +#else +error - unsupported FFT/NTT combination +#endif // Word and Word2 define the data type for FFT integers passed between the CPU and GPU. -#define COMBO_FFT (FFT_FP64 + FFT_FP32 + NTT_GF31 + NTT_GF61 + NTT_NCW > 1) -#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61 & !NTT_NCW) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32 & !NTT_NCW) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31 & !NTT_NCW) -#define WordSize 8 +#if WordSize == 8 typedef i64 Word; typedef long2 Word2; -#elif !COMBO_FFT | (FFT_FP32 & NTT_GF31 & !FFT_FP64 & !NTT_GF61 & !NTT_NCW) -#define WordSize 4 +#elif WordSize == 4 typedef i32 Word; typedef int2 Word2; #else -error - unsupported FFT/NTT combination +error - unsupported integer WordSize #endif // Routine to create a pair diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index 3c80008e..29ddfa64 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -573,14 +573,14 @@ void OVERLOAD tabMul(u32 WG, TrigFP32 trig, F2 *u, u32 n, u32 f, u32 me) { // This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. - if (TABMUL_CHAIN) { + if (TABMUL_CHAIN32) { chainMul (n, u, trig[p], 0); return; } // Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. - if (!TABMUL_CHAIN) { + if (!TABMUL_CHAIN32) { if (n >= 8) { u[1] = cmulFancy(u[1], trig[p]); } else { @@ -688,14 +688,14 @@ void OVERLOAD tabMul(u32 WG, TrigGF31 trig, GF31 *u, u32 n, u32 f, u32 me) { // This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. - if (TABMUL_CHAIN) { + if (TABMUL_CHAIN31) { chainMul (n, u, trig[p], 0); return; } // Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. - if (!TABMUL_CHAIN) { + if (!TABMUL_CHAIN31) { for (u32 i = 1; i < n; ++i) { u[i] = cmul(u[i], trig[(i-1)*WG + p]); } @@ -826,14 +826,14 @@ void OVERLOAD tabMul(u32 WG, TrigGF61 trig, GF61 *u, u32 n, u32 f, u32 me) { // This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. - if (TABMUL_CHAIN) { + if (TABMUL_CHAIN61) { chainMul (n, u, trig[p], 0); return; } // Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. - if (!TABMUL_CHAIN) { + if (!TABMUL_CHAIN61) { for (u32 i = 1; i < n; ++i) { u[i] = cmul(u[i], trig[(i-1)*WG + p]); } diff --git a/src/cl/tailmul.cl b/src/cl/tailmul.cl index b0e234be..1cdd5db0 100644 --- a/src/cl/tailmul.cl +++ b/src/cl/tailmul.cl @@ -320,7 +320,7 @@ KERNEL(G_H) tailMulGF31(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; -#if TAIL_TRIGS >= 1 +#if TAIL_TRIGS31 >= 1 GF31 trig = smallTrig31[height_trigs + me]; // Trig values for line zero, should be cached #if SINGLE_WIDE GF31 mult = smallTrig31[height_trigs + G_H + line1]; @@ -449,7 +449,7 @@ KERNEL(G_H) tailMulGF61(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; -#if TAIL_TRIGS >= 1 +#if TAIL_TRIGS61 >= 1 GF61 trig = smallTrig61[height_trigs + me]; // Trig values for line zero, should be cached #if SINGLE_WIDE GF61 mult = smallTrig61[height_trigs + G_H + line1]; diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index db88d203..bf960f1c 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -409,11 +409,11 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif // Compute trig values from scratch. Good on GPUs with high DP throughput. -#if TAIL_TRIGS == 2 +#if TAIL_TRIGS32 == 2 F2 trig = slowTrig_N(line1 + me * H, ND / NH); // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. -#elif TAIL_TRIGS == 1 +#elif TAIL_TRIGS32 == 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; @@ -454,9 +454,9 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { } bar(); - fft_HEIGHT(lds, v, smallTrigF2, w); + fft_HEIGHT(lds, v, smallTrigF2); bar(); - fft_HEIGHT(lds, u, smallTrigF2, w); + fft_HEIGHT(lds, u, smallTrigF2); writeTailFusedLine(v, outF2, memline2, me); writeTailFusedLine(u, outF2, memline1, me); @@ -523,11 +523,11 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif // Compute trig values from scratch. Good on GPUs with high DP throughput. -#if TAIL_TRIGS == 2 +#if TAIL_TRIGS32 == 2 F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. -#elif TAIL_TRIGS == 1 +#elif TAIL_TRIGS32 == 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; @@ -647,7 +647,7 @@ KERNEL(G_H) tailSquareZeroGF31(P(T2) out, CP(T2) in, Trig smallTrig) { // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; -#if TAIL_TRIGS >= 1 +#if TAIL_TRIGS31 >= 1 GF31 trig = smallTrig31[height_trigs + me]; #if SINGLE_WIDE GF31 mult = smallTrig31[height_trigs + G_H + line]; @@ -710,12 +710,8 @@ KERNEL(G_H) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { fft_HEIGHT(lds, v, smallTrig31); #endif - // Compute trig values from scratch. Good on GPUs with relatively slow memory. -#if 0 && TAIL_TRIGS == 2 - GF31 trig = slowTrigGF31(line1 + me * H, ND / NH); - // Do a little bit of memory access and a little bit of math. -#elif TAIL_TRIGS >= 1 +#if TAIL_TRIGS31 >= 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; @@ -823,12 +819,8 @@ KERNEL(G_H * 2) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { new_fft_HEIGHT2_1(lds, u, smallTrig31); #endif - // Compute trig values from scratch. Good on GPUs with high MUL throughput?? -#if 0 && TAIL_TRIGS == 2 - GF31 trig = slowTrigGF31(line + H * lowMe, ND / NH * 2); - // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. -#elif TAIL_TRIGS >= 1 +#if TAIL_TRIGS31 >= 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; @@ -948,7 +940,7 @@ KERNEL(G_H) tailSquareZeroGF61(P(T2) out, CP(T2) in, Trig smallTrig) { // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; -#if TAIL_TRIGS >= 1 +#if TAIL_TRIGS61 >= 1 GF61 trig = smallTrig61[height_trigs + me]; #if SINGLE_WIDE GF61 mult = smallTrig61[height_trigs + G_H + line]; @@ -1011,12 +1003,8 @@ KERNEL(G_H) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { fft_HEIGHT(lds, v, smallTrig61); #endif - // Compute trig values from scratch. Good on GPUs with relatively slow memory?? -#if 0 && TAIL_TRIGS == 2 - GF61 trig = slowTrigGF61(line1 + me * H, ND / NH); - // Do a little bit of memory access and a little bit of math. -#elif TAIL_TRIGS >= 1 +#if TAIL_TRIGS61 >= 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; @@ -1124,12 +1112,8 @@ KERNEL(G_H * 2) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { new_fft_HEIGHT2_1(lds, u, smallTrig61); #endif - // Compute trig values from scratch. Good on GPUs with high MUL throughput?? -#if 0 && TAIL_TRIGS == 2 - GF61 trig = slowTrigGF61(line + H * lowMe, ND / NH * 2); - // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. -#elif TAIL_TRIGS >= 1 +#if TAIL_TRIGS61 >= 1 // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. u32 height_trigs = SMALL_HEIGHT*1; diff --git a/src/cl/tailutil.cl b/src/cl/tailutil.cl index dd1cf41f..01710cf1 100644 --- a/src/cl/tailutil.cl +++ b/src/cl/tailutil.cl @@ -7,7 +7,16 @@ // 1 = Limited memory accesses and some DP computation. Tuned for Radeon VII a GPU with good DP performance. // 0 = No DP computation. Trig vaules read from memory. Good for GPUs with poor DP performance (a typical consumer grade GPU). #if !defined(TAIL_TRIGS) -#define TAIL_TRIGS 2 // Default is compute trig values from scratch +#define TAIL_TRIGS 2 // Default is compute trig values from scratch for FP64 +#endif +#if !defined(TAIL_TRIGS31) +#define TAIL_TRIGS31 0 // Default is read all trig values from memory for GF31 +#endif +#if !defined(TAIL_TRIGS32) +#define TAIL_TRIGS32 2 // Default is compute trig values from scratch for FP32 +#endif +#if !defined(TAIL_TRIGS61) +#define TAIL_TRIGS61 0 // Default is read all trig values from memory for GF61 #endif // TAIL_KERNELS setting: diff --git a/src/common.h b/src/common.h index b4ff3d65..77784b05 100644 --- a/src/common.h +++ b/src/common.h @@ -23,18 +23,9 @@ using namespace std; namespace std::filesystem{}; namespace fs = std::filesystem; -#define FFT_FP64 1 -#define FFT_FP32 0 -#define NTT_GF31 0 -#define NTT_GF61 0 -#define NTT_NCW 0 - -// When using multiple primes in an NTT the size of an integer FFT "word" grows such that we need to support words larger than 32-bits -#if (FFT_FP64 && NTT_GF31) | (FFT_FP32 && NTT_GF61) | (NTT_GF31 && NTT_GF61) +// When using multiple primes in an NTT the size of an integer FFT "word" can be 64 bits. Original FP64 FFT needs only 32 bits. +// C code will use i64 integer data. The code that reads and writes GPU buffers will downsize the integers to 32 bits when required. typedef i64 Word; -#else -typedef i32 Word; -#endif using double2 = pair; using float2 = pair; diff --git a/src/fftbpw.h b/src/fftbpw.h index 1e8712d9..928f45be 100644 --- a/src/fftbpw.h +++ b/src/fftbpw.h @@ -91,3 +91,101 @@ { "4K:14:1K", {16.942, 17.055, 17.160, 17.027, 17.127, 17.306}}, { "4K:15:1K", {17.021, 17.007, 17.137, 17.104, 17.087, 17.282}}, { "4K:16:1K", {16.744, 16.887, 16.966, 16.921, 17.048, 17.208}}, +// FFT3161 +{ "1:256:2:256", {39.74, 39.74, 39.74, 39.74, 39.74, 39.74}}, +{ "1:256:4:256", {39.64, 39.64, 39.64, 39.64, 39.64, 39.64}}, +{ "1:256:8:256", {39.54, 39.54, 39.54, 39.54, 39.54, 39.54}}, +{ "1:512:4:256", {39.54, 39.54, 39.54, 39.54, 39.54, 39.54}}, +{"1:256:16:256", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, +{ "1:512:8:256", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, +{ "1:512:4:512", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, +{ "1:1K:8:256", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, +{"1:512:16:256", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, +{ "1:512:8:512", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, +{ "1:1K:16:256", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, +{ "1:1K:8:512", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, +{"1:512:16:512", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, +{ "1:1K:16:512", {39.14, 39.14, 39.14, 39.14, 39.14, 39.14}}, +{ "1:1K:8:1K", {39.14, 39.14, 39.14, 39.14, 39.14, 39.14}}, +{ "1:1K:16:1K", {39.04, 39.04, 39.04, 39.04, 39.04, 39.04}}, +{ "1:4K:16:512", {38.94, 38.94, 38.94, 38.94, 38.94, 38.94}}, +{ "1:4K:16:1K", {38.84, 38.84, 38.84, 38.84, 38.84, 38.84}}, +// FFT3261 +{ "2:256:2:256", {32.05, 32.35, 32.35, 32.35, 32.35, 32.35}}, +{ "2:256:4:256", {31.95, 32.25, 32.25, 32.25, 32.25, 32.25}}, +{ "2:256:8:256", {31.85, 32.15, 32.15, 32.15, 32.15, 32.15}}, +{ "2:512:4:256", {31.85, 32.15, 32.15, 32.15, 32.15, 32.15}}, +{"2:256:16:256", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, +{ "2:512:8:256", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, +{ "2:512:4:512", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, +{ "2:1K:8:256", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, +{"2:512:16:256", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, +{ "2:512:8:512", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, +{ "2:1K:16:256", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, +{ "2:1K:8:512", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, +{"2:512:16:512", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, +{ "2:1K:16:512", {31.45, 31.75, 31.75, 31.75, 31.75, 31.75}}, +{ "2:1K:8:1K", {31.45, 31.75, 31.75, 31.75, 31.75, 31.75}}, +{ "2:1K:16:1K", {31.35, 31.65, 31.65, 31.65, 31.65, 31.65}}, +{ "2:4K:16:512", {31.25, 31.55, 31.55, 31.55, 31.55, 31.55}}, +{ "2:4K:16:1K", {31.15, 31.45, 31.45, 31.45, 31.45, 31.45}}, +// FFT61 +{ "3:256:2:256", {24.20, 24.20, 24.20, 24.20, 24.20, 24.20}}, +{ "3:256:4:256", {24.10, 24.10, 24.10, 24.10, 24.10, 24.10}}, +{ "3:256:8:256", {24.00, 24.00, 24.00, 24.00, 24.00, 24.00}}, +{ "3:512:4:256", {24.00, 24.00, 24.00, 24.00, 24.00, 24.00}}, +{"3:256:16:256", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, +{ "3:512:8:256", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, +{ "3:512:4:512", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, +{ "3:1K:8:256", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, +{"3:512:16:256", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, +{ "3:512:8:512", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, +{ "3:1K:16:256", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, +{ "3:1K:8:512", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, +{"3:512:16:512", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, +{ "3:1K:16:512", {23.60, 23.60, 23.60, 23.60, 23.60, 23.60}}, +{ "3:1K:8:1K", {23.60, 23.60, 23.60, 23.60, 23.60, 23.60}}, +{ "3:1K:16:1K", {23.50, 23.50, 23.50, 23.50, 23.50, 23.50}}, +{ "3:4K:16:512", {23.40, 23.40, 23.40, 23.40, 23.40, 23.40}}, +{ "3:4K:16:1K", {23.30, 23.30, 23.30, 23.30, 23.30, 23.30}}, +// FFT3231 +{ "50:256:2:256", {16.95, 34.26, 34.26, 34.26, 34.26, 34.26}}, +{ "50:256:4:256", {16.85, 34.16, 34.16, 34.16, 34.16, 34.16}}, +{ "50:256:8:256", {16.75, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{ "50:512:4:256", {16.75, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{"50:256:16:256", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "50:512:8:256", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "50:512:4:512", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "50:1K:8:256", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{"50:512:16:256", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{ "50:512:8:512", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{ "50:1K:16:256", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{ "50:1K:8:512", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{"50:512:16:512", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{ "50:1K:16:512", {16.35, 33.66, 33.66, 33.66, 33.66, 33.66}}, +{ "50:1K:8:1K", {16.35, 33.66, 33.66, 33.66, 33.66, 33.66}}, +{ "50:1K:16:1K", {16.25, 33.56, 33.56, 33.56, 33.56, 33.56}}, +{ "50:4K:16:512", {16.15, 33.46, 33.46, 33.46, 33.46, 33.46}}, +{ "50:4K:16:1K", {16.05, 33.36, 33.36, 33.36, 33.36, 33.36}}, +// FFT6431 +{ "51:256:2:256", {34.26, 34.26, 34.26, 34.26, 34.26, 34.26}}, +{ "51:256:4:256", {34.16, 34.16, 34.16, 34.16, 34.16, 34.16}}, +{ "51:256:8:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{ "51:512:4:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{"51:256:16:256", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "51:512:8:256", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "51:512:4:512", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, +{ "51:1K:8:256", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{"51:512:16:256", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{ "51:512:8:512", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, +{ "51:1K:16:256", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{ "51:1K:8:512", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{"51:512:16:512", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, +{ "51:1K:16:512", {33.66, 33.66, 33.66, 33.66, 33.66, 33.66}}, +{ "51:1K:8:1K", {33.66, 33.66, 33.66, 33.66, 33.66, 33.66}}, +{ "51:1K:16:1K", {33.56, 33.56, 33.56, 33.56, 33.56, 33.56}}, +{ "51:4K:16:512", {33.46, 33.46, 33.46, 33.46, 33.46, 33.46}}, +{ "51:4K:16:1K", {33.36, 33.36, 33.36, 33.36, 33.36, 33.36}}, + + + diff --git a/src/tune.cpp b/src/tune.cpp index e14fbe46..774cccde 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -103,26 +103,26 @@ string formatConfigResults(const vector& results) { } // namespace -double Tune::maxBpw(FFTConfig fft) { +float Tune::maxBpw(FFTConfig fft) { -// double bpw = oldBpw; +// float bpw = oldBpw; - const double TARGET = 28; + const float TARGET = 28; const u32 sample_size = 5; // Estimate how much bpw needs to change to increase/decrease Z by 1. // This doesn't need to be a very accurate estimate. // This estimate comes from analyzing a 4M FFT and a 7.5M FFT. // The 4M FFT needed a .015 step, the 7.5M FFT needed a .012 step. - double bpw_step = .015 + (log2(fft.size()) - log2(4.0*1024*1024)) / (log2(7.5*1024*1024) - log2(4.0*1024*1024)) * (.012 - .015); + float bpw_step = .015 + (log2(fft.size()) - log2(4.0*1024*1024)) / (log2(7.5*1024*1024) - log2(4.0*1024*1024)) * (.012 - .015); // Pick a bpw that might be close to Z=34, it is best to err on the high side of Z=34 - double bpw1 = fft.maxBpw() - 9 * bpw_step; // Old bpw gave Z=28, we want Z=34 (or more) + float bpw1 = fft.maxBpw() - 9 * bpw_step; // Old bpw gave Z=28, we want Z=34 (or more) // The code below was used when building the maxBpw table from scratch // u32 non_best_width = N_VARIANT_W - 1 - variant_W(fft.variant); // Number of notches below best-Z width variant // u32 non_best_middle = N_VARIANT_M - 1 - variant_M(fft.variant); // Number of notches below best-Z middle variant -// double bpw1 = 18.3 - 0.275 * (log2(fft.size()) - log2(256 * 13 * 1024 * 2)) - // Default max bpw from an old gpuowl version +// float bpw1 = 18.3 - 0.275 * (log2(fft.size()) - log2(256 * 13 * 1024 * 2)) - // Default max bpw from an old gpuowl version // 9 * bpw_step - // Default above should give Z=28, we want Z=34 (or more) // (.08/.012 * bpw_step) * non_best_width - // 7.5M FFT has ~.08 bpw difference for each width variant below best variant // (.06 + .04 * (fft.shape.middle - 4) / 11) * non_best_middle; // Assume .1 bpw difference MIDDLE=15 and .06 for MIDDLE=4 @@ -130,11 +130,11 @@ double Tune::maxBpw(FFTConfig fft) { //if (fft.size() < 512000) bpw1 = 19, bpw_step = .02; // Fine tune our estimate for Z=34 - double z1 = zForBpw(bpw1, fft, 1); + float z1 = zForBpw(bpw1, fft, 1); printf ("Guess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); while (z1 < 31.0 || z1 > 37.0) { - double prev_bpw1 = bpw1; - double prev_z1 = z1; + float prev_bpw1 = bpw1; + float prev_z1 = z1; bpw1 = bpw1 + (z1 - 34) * bpw_step; z1 = zForBpw(bpw1, fft, 1); printf ("Reguess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); @@ -147,12 +147,12 @@ printf ("Reguess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bp z1 = (z1 + (sample_size - 1) * zForBpw(bpw1, fft, sample_size - 1)) / sample_size; // Pick a bpw somewhere near Z=22 then fine tune the guess - double bpw2 = bpw1 + (z1 - 22) * bpw_step; - double z2 = zForBpw(bpw2, fft, 1); + float bpw2 = bpw1 + (z1 - 22) * bpw_step; + float z2 = zForBpw(bpw2, fft, 1); printf ("Guess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bpw2, z2); while (z2 < 20.0 || z2 > 25.0) { - double prev_bpw2 = bpw2; - double prev_z2 = z2; + float prev_bpw2 = bpw2; + float prev_z2 = z2; // bool error_recovery = (z2 <= 0.0); // if (error_recovery) bpw2 -= bpw_step; else bpw2 = bpw2 + (z2 - 21) * bpw_step; @@ -171,12 +171,12 @@ printf ("Reguess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bp return bpw2 + (bpw1 - bpw2) * (TARGET - z2) / (z1 - z2); } -double Tune::zForBpw(double bpw, FFTConfig fft, u32 count) { +float Tune::zForBpw(float bpw, FFTConfig fft, u32 count) { u32 exponent = (count == 1) ? primes.prevPrime(fft.size() * bpw) : primes.nextPrime(fft.size() * bpw); - double total_z = 0.0; + float total_z = 0.0; for (u32 i = 0; i < count; i++, exponent = primes.nextPrime (exponent + 1)) { auto [ok, res, roeSq, roeMul] = Gpu::make(q, exponent, shared, fft, {}, false)->measureROE(true); - double z = roeSq.z(); + float z = roeSq.z(); total_z += z; log("Zforbpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); if (!ok) { log("Error at bpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); continue; } @@ -195,8 +195,8 @@ void Tune::ztune() { u32 variant = 202; u32 sample_size = 5; FFTConfig fft{shape, variant, CARRY_AUTO}; - for (double bpw = 18.18; bpw < 18.305; bpw += 0.02) { - double z = zForBpw(bpw, fft, sample_size); + for (float bpw = 18.18; bpw < 18.305; bpw += 0.02) { + float z = zForBpw(bpw, fft, sample_size); log ("Avg zForBpw %s %.2f %.2f\n", fft.spec().c_str(), bpw, z); } } @@ -213,7 +213,7 @@ void Tune::ztune() { if (shape.width > 1024) bpw_variants[0] = 100, bpw_variants[3] = 110; // Copy the existing bpw array (in case we're replacing only some of the entries) - array bpw; + array bpw; bpw = shape.bpw; // Not all shapes have their maximum bpw per-computed. But one can work on a non-favored shape by specifying it on the command line. @@ -244,10 +244,10 @@ void Tune::carryTune() { if (prevSize == fft.size()) { continue; } prevSize = fft.size(); - vector zv; + vector zv; double m = 0; - const double mid = fft.shape.carry32BPW(); - for (double bpw : {mid - 0.05, mid + 0.05}) { + const float mid = fft.shape.carry32BPW(); + for (float bpw : {mid - 0.05, mid + 0.05}) { u32 exponent = primes.nearestPrime(fft.size() * bpw); auto [ok, carry] = Gpu::make(q, exponent, shared, fft, {}, false)->measureCarry(); m = carry.max; @@ -255,7 +255,7 @@ void Tune::carryTune() { zv.push_back(carry.z()); } - double avg = (zv[0] + zv[1]) / 2; + float avg = (zv[0] + zv[1]) / 2; u32 exponent = fft.shape.carry32BPW() * fft.size(); double pErr100 = -expm1(-exp(-avg) * exponent * 100); log("%14s %.3f : %.3f (%.3f %.3f) %f %.0f%%\n", fft.spec().c_str(), mid, avg, zv[0], zv[1], m, pErr100 * 100); @@ -325,288 +325,574 @@ void Tune::ctune() { log("\nBest configs (lines can be copied to config.txt):\n%s", formatConfigResults(results).c_str()); } +// Add better -use settings to list of changes to be made to config.txt +void configsUpdate(double current_cost, double best_cost, double threshold, const char *key, u32 value, vector> &newConfigKeyVals, vector> &suggestedConfigKeyVals) { + if (best_cost == current_cost) return; + // If best cost is better than current cost by a substantial margin (the threshold) then add the key value pair to suggestedConfigKeyVals + if (best_cost < (1.0 - threshold) * current_cost) + newConfigKeyVals.push_back({key, value}); + // Otherwise, add the key value pair to newConfigKeyVals + else + suggestedConfigKeyVals.push_back({key, value}); +} + void Tune::tune() { Args *args = shared.args; vector shapes = FFTShape::multiSpec(args->fftSpec); - + // There are some options and variants that are different based on GPU manufacturer bool AMDGPU = isAmdGpu(q->context->deviceId()); - // Look for best settings of various options + bool tune_config = 1; + bool time_FFTs = 0; + bool time_NTTs = 0; + u64 min_exponent = 75000000; + u64 max_exponent = 350000000; + if (!args->fftSpec.empty()) { min_exponent = 0; max_exponent = 1000000000000ull; } + + // Parse input args + for (const string& s : split(args->tune, ',')) { + if (s.empty()) continue; + if (s == "noconfig") tune_config = 0; + if (s == "fp64") time_FFTs = 1; + if (s == "ntt") time_NTTs = 1; + auto keyVal = split(s, '='); + if (keyVal.size() == 2) { + if (keyVal.front() == "minexp") min_exponent = stoull(keyVal.back()); + if (keyVal.front() == "maxexp") max_exponent = stoull(keyVal.back()); + } + } - if (1) { - u32 variant = 101; + // Look for best settings of various options. Append best settings to config.txt. + if (tune_config) { + vector> newConfigKeyVals; + vector> suggestedConfigKeyVals; + + // Select/init the default FFTshape(s) and FFTConfig(s) for optimal -use options testing + FFTShape defaultFFTShape, defaultNTTShape, *defaultShape; + + // If user gave us an fft-spec, use that to time options + if (!args->fftSpec.empty()) { + defaultShape = &shapes[0]; + if (shapes[0].fft_type == FFT64) { + defaultFFTShape = shapes[0]; + time_FFTs = 1; + } else { + defaultNTTShape = shapes[0]; + time_NTTs = 1; + } + } + // If user specified FP64-timings, time a wavefront exponent using an 7.5M FFT + // If user specified NTT-timings, time a wavefront exponent using an 4M M31+M61 NTT + else if (time_FFTs || time_NTTs) { + if (time_FFTs) { + defaultFFTShape = FFTShape(FFT64, 512, 15, 512); + defaultShape = &defaultFFTShape; + } + if (time_NTTs) { + defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); + defaultShape = &defaultNTTShape; + } + } + // No user specifications. Time an FP64 FFT and a GF31*GF61 NTT to see if the GPU is more suited for FP64 work or NTT work. + else { + log("Checking whether this GPU is better suited for double-precision FFTs or integer NTTs.\n"); + defaultFFTShape = FFTShape(FFT64, 512, 16, 512); + FFTConfig fft{defaultFFTShape, 101, CARRY_32}; + double fp64_time = Gpu::make(q, 141000001, shared, fft, {}, false)->timePRP(); + log("Time for FP64 FFT %12s is %6.1f\n", fft.spec().c_str(), fp64_time); + defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); + FFTConfig ntt{defaultNTTShape, 202, CARRY_AUTO}; + double ntt_time = Gpu::make(q, 141000001, shared, ntt, {}, false)->timePRP(); + log("Time for M31*M61 NTT %12s is %6.1f\n", ntt.spec().c_str(), ntt_time); + if (fp64_time < ntt_time) { + defaultShape = &defaultFFTShape; + time_FFTs = 1; + if (fp64_time < 0.80 * ntt_time) { + log("FP64 FFTs are significantly faster than NTTs. No NTT tuning will be performed.\n"); + } else { + log("FP64 FFTs are not significantly faster than NTTs. NTT tuning will be performed.\n"); + time_NTTs = 1; + } + } else { + defaultShape = &defaultNTTShape; + time_NTTs = 1; + if (fp64_time > 1.20 * ntt_time) { + log("FP64 FFTs are significantly slower than NTTs. No FP64 tuning will be performed.\n"); + } else { + log("FP64 FFTs are not significantly slower than NTTs. FP64 tuning will be performed.\n"); + time_FFTs = 1; + } + } + } + + log("\n"); + log("Beginning timing of various options. These settings will be appended to config.txt. Please read config.txt after -tune completes.\n"); + log("\n"); + + u32 variant = (defaultShape == &defaultFFTShape) ? 101 : 202; //GW: if fft spec on the command line specifies a variant then we should use that variant (I get some interesting results with 000 vs 101 vs 201 vs 202 likely due to rocm optimizer) - // Find best IN_WG,IN_SIZEX setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best IN_WG,IN_SIZEX,OUT_WG,OUT_SIZEX settings + if (1/*option to time IN/OUT settings*/) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_in_wg = 0; u32 best_in_sizex = 0; + u32 current_in_wg = args->value("IN_WG", 128); + u32 current_in_sizex = args->value("IN_SIZEX", 16); double best_cost = -1.0; + double current_cost = -1.0; for (u32 in_wg : {64, 128, 256}) { for (u32 in_sizex : {8, 16, 32}) { - shared.args->flags["IN_WG"] = to_string(in_wg); - shared.args->flags["IN_SIZEX"] = to_string(in_sizex); + args->flags["IN_WG"] = to_string(in_wg); + args->flags["IN_SIZEX"] = to_string(in_sizex); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using IN_WG=%u, IN_SIZEX=%u is %6.1f\n", fft.spec().c_str(), in_wg, in_sizex, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } - } + if (in_wg == current_in_wg && in_sizex == current_in_sizex) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } + } } log("Best IN_WG, IN_SIZEX is %u, %u. Default is 128, 16.\n", best_in_wg, best_in_sizex); - shared.args->flags["IN_WG"] = to_string(best_in_wg); - shared.args->flags["IN_SIZEX"] = to_string(best_in_sizex); - } + configsUpdate(current_cost, best_cost, 0.003, "IN_WG", best_in_wg, newConfigKeyVals, suggestedConfigKeyVals); + configsUpdate(current_cost, best_cost, 0.003, "IN_SIZEX", best_in_sizex, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["IN_WG"] = to_string(best_in_wg); + args->flags["IN_SIZEX"] = to_string(best_in_sizex); - // Find best OUT_WG,OUT_SIZEX setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; - u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_out_wg = 0; u32 best_out_sizex = 0; - double best_cost = -1.0; + u32 current_out_wg = args->value("OUT_WG", 128); + u32 current_out_sizex = args->value("OUT_SIZEX", 16); + best_cost = -1.0; + current_cost = -1.0; for (u32 out_wg : {64, 128, 256}) { for (u32 out_sizex : {8, 16, 32}) { - shared.args->flags["OUT_WG"] = to_string(out_wg); - shared.args->flags["OUT_SIZEX"] = to_string(out_sizex); + args->flags["OUT_WG"] = to_string(out_wg); + args->flags["OUT_SIZEX"] = to_string(out_sizex); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using OUT_WG=%u, OUT_SIZEX=%u is %6.1f\n", fft.spec().c_str(), out_wg, out_sizex, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } - } + if (out_wg == current_out_wg && out_sizex == current_out_sizex) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } + } } log("Best OUT_WG, OUT_SIZEX is %u, %u. Default is 128, 16.\n", best_out_wg, best_out_sizex); - shared.args->flags["OUT_WG"] = to_string(best_out_wg); - shared.args->flags["OUT_SIZEX"] = to_string(best_out_sizex); + configsUpdate(current_cost, best_cost, 0.003, "OUT_WG", best_out_wg, newConfigKeyVals, suggestedConfigKeyVals); + configsUpdate(current_cost, best_cost, 0.003, "OUT_SIZEX", best_out_sizex, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["OUT_WG"] = to_string(best_out_wg); + args->flags["OUT_SIZEX"] = to_string(best_out_sizex); + } + + // Find best PAD setting. Default is 256 bytes for AMD, 0 for all others. + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_pad = 0; + u32 current_pad = args->value("PAD", AMDGPU ? 256 : 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 pad : {0, 64, 128, 256, 512}) { + args->flags["PAD"] = to_string(pad); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); + if (pad == current_pad) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } + } + log("Best PAD is %u bytes. Default PAD is %u bytes.\n", best_pad, AMDGPU ? 256 : 0); + configsUpdate(current_cost, best_cost, 0.000, "PAD", best_pad, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["PAD"] = to_string(best_pad); + } + + // Find best NONTEMPORAL setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_nontemporal = 0; + u32 current_nontemporal = args->value("NONTEMPORAL", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 nontemporal : {0, 1}) { + args->flags["NONTEMPORAL"] = to_string(nontemporal); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); + if (nontemporal == current_nontemporal) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } + } + log("Best NONTEMPORAL is %u. Default NONTEMPORAL is 0.\n", best_nontemporal); + configsUpdate(current_cost, best_cost, 0.000, "NONTEMPORAL", best_nontemporal, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["NONTEMPORAL"] = to_string(best_nontemporal); } // Find best FAST_BARRIER setting - if (1 && AMDGPU) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + if (AMDGPU) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_fast_barrier = 0; + u32 current_fast_barrier = args->value("FAST_BARRIER", 0); double best_cost = -1.0; + double current_cost = -1.0; for (u32 fast_barrier : {0, 1}) { - shared.args->flags["FAST_BARRIER"] = to_string(fast_barrier); + args->flags["FAST_BARRIER"] = to_string(fast_barrier); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using FAST_BARRIER=%u is %6.1f\n", fft.spec().c_str(), fast_barrier, cost); + if (fast_barrier == current_fast_barrier) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_fast_barrier = fast_barrier; } } log("Best FAST_BARRIER is %u. Default FAST_BARRIER is 0.\n", best_fast_barrier); - shared.args->flags["FAST_BARRIER"] = to_string(best_fast_barrier); + configsUpdate(current_cost, best_cost, 0.000, "FAST_BARRIER", best_fast_barrier, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["FAST_BARRIER"] = to_string(best_fast_barrier); } - // Find best TAIL_TRIGS setting + // Find best TAIL_KERNELS setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_kernels = 0; + u32 current_tail_kernels = args->value("TAIL_KERNELS", 2); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_kernels : {0, 1, 2, 3}) { + args->flags["TAIL_KERNELS"] = to_string(tail_kernels); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); + if (tail_kernels == current_tail_kernels) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } + } + if (best_tail_kernels & 1) + log("Best TAIL_KERNELS is %u. Default TAIL_KERNELS is 2.\n", best_tail_kernels); + else + log("Best TAIL_KERNELS is %u (but best may be %u when running two workers on one GPU). Default TAIL_KERNELS is 2.\n", best_tail_kernels, best_tail_kernels | 1); + configsUpdate(current_cost, best_cost, 0.000, "TAIL_KERNELS", best_tail_kernels, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_KERNELS"] = to_string(best_tail_kernels); + } + + // Find best TAIL_TRIGS setting + if (time_FFTs) { + FFTConfig fft{defaultFFTShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS", 2); double best_cost = -1.0; + double current_cost = -1.0; for (u32 tail_trigs : {0, 1, 2}) { - shared.args->flags["TAIL_TRIGS"] = to_string(tail_trigs); + args->flags["TAIL_TRIGS"] = to_string(tail_trigs); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using TAIL_TRIGS=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } } log("Best TAIL_TRIGS is %u. Default TAIL_TRIGS is 2.\n", best_tail_trigs); - shared.args->flags["TAIL_TRIGS"] = to_string(best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS"] = to_string(best_tail_trigs); } - // Find best TAIL_KERNELS setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TAIL_TRIGS31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_tail_kernels = 0; + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS31", 0); double best_cost = -1.0; - for (u32 tail_kernels : {0, 1, 2, 3}) { - shared.args->flags["TAIL_KERNELS"] = to_string(tail_kernels); + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1}) { + args->flags["TAIL_TRIGS31"] = to_string(tail_trigs); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } + log("Time for %12s using TAIL_TRIGS31=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } } - if (best_tail_kernels & 1) - log("Best TAIL_KERNELS is %u. Default TAIL_KERNELS is 2.\n", best_tail_kernels); - else - log("Best TAIL_KERNELS is %u (but best may be %u when running two workers on one GPU). Default TAIL_KERNELS is 2.\n", best_tail_kernels, best_tail_kernels | 1); - shared.args->flags["TAIL_KERNELS"] = to_string(best_tail_kernels); + log("Best TAIL_TRIGS31 is %u. Default TAIL_TRIGS31 is 0.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS31", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS31"] = to_string(best_tail_trigs); + } + + // Find best TAIL_TRIGS32 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS32", 2); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1, 2}) { + args->flags["TAIL_TRIGS32"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TAIL_TRIGS32=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } + } + log("Best TAIL_TRIGS32 is %u. Default TAIL_TRIGS32 is 2.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS32", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS32"] = to_string(best_tail_trigs); + } + + // Find best TAIL_TRIGS61 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF61) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS61", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1}) { + args->flags["TAIL_TRIGS61"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TAIL_TRIGS61=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } + } + log("Best TAIL_TRIGS61 is %u. Default TAIL_TRIGS61 is 0.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS61", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS61"] = to_string(best_tail_trigs); } // Find best TABMUL_CHAIN setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, 101, CARRY_32}; + if (time_FFTs) { + FFTConfig fft{defaultFFTShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN", 0); double best_cost = -1.0; + double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { - shared.args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); + args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using TABMUL_CHAIN=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } log("Best TABMUL_CHAIN is %u. Default TABMUL_CHAIN is 0.\n", best_tabmul_chain); - shared.args->flags["TABMUL_CHAIN"] = to_string(best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN"] = to_string(best_tabmul_chain); } - // Find best PAD setting. Default is 256 bytes for AMD, 0 for all others. - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TABMUL_CHAIN31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_pad = 0; + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN31", 0); double best_cost = -1.0; - for (u32 pad : {0, 64, 128, 256, 512}) { - shared.args->flags["PAD"] = to_string(pad); + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN31"] = to_string(tabmul_chain); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } + log("Time for %12s using TABMUL_CHAIN31=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } - log("Best PAD is %u bytes. Default PAD is %u bytes.\n", best_pad, AMDGPU ? 256 : 0); - shared.args->flags["PAD"] = to_string(best_pad); + log("Best TABMUL_CHAIN31 is %u. Default TABMUL_CHAIN31 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN31", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN31"] = to_string(best_tabmul_chain); } - // Find best NONTEMPORAL setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TABMUL_CHAIN61 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_nontemporal = 0; + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN32", 0); double best_cost = -1.0; - for (u32 nontemporal : {0, 1}) { - shared.args->flags["NONTEMPORAL"] = to_string(nontemporal); + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN32"] = to_string(tabmul_chain); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } + log("Time for %12s using TABMUL_CHAIN32=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } - log("Best NONTEMPORAL is %u. Default NONTEMPORAL is 0.\n", best_nontemporal); - shared.args->flags["NONTEMPORAL"] = to_string(best_nontemporal); + log("Best TABMUL_CHAIN32 is %u. Default TABMUL_CHAIN32 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN32", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN32"] = to_string(best_tabmul_chain); + } + + // Find best TABMUL_CHAIN61 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF61) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN61", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN61"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + log("Time for %12s using TABMUL_CHAIN61=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } + } + log("Best TABMUL_CHAIN61 is %u. Default TABMUL_CHAIN61 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN61", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN61"] = to_string(best_tabmul_chain); } // Find best UNROLL_W setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_unroll_w = 0; + u32 current_unroll_w = args->value("UNROLL_W", AMDGPU ? 0 : 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 unroll_w : {0, 1}) { - shared.args->flags["UNROLL_W"] = to_string(unroll_w); + args->flags["UNROLL_W"] = to_string(unroll_w); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using UNROLL_W=%u is %6.1f\n", fft.spec().c_str(), unroll_w, cost); + if (unroll_w == current_unroll_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_w = unroll_w; } } log("Best UNROLL_W is %u. Default UNROLL_W is %u.\n", best_unroll_w, AMDGPU ? 0 : 1); - shared.args->flags["UNROLL_W"] = to_string(best_unroll_w); + configsUpdate(current_cost, best_cost, 0.003, "UNROLL_W", best_unroll_w, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["UNROLL_W"] = to_string(best_unroll_w); } // Find best UNROLL_H setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_unroll_h = 0; + u32 current_unroll_h = args->value("UNROLL_H", AMDGPU && defaultShape->height >= 1024 ? 0 : 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 unroll_h : {0, 1}) { - shared.args->flags["UNROLL_H"] = to_string(unroll_h); + args->flags["UNROLL_H"] = to_string(unroll_h); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using UNROLL_H=%u is %6.1f\n", fft.spec().c_str(), unroll_h, cost); + if (unroll_h == current_unroll_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_h = unroll_h; } } - log("Best UNROLL_H is %u. Default UNROLL_H is %u.\n", best_unroll_h, AMDGPU && shape.height >= 1024 ? 0 : 1); - shared.args->flags["UNROLL_H"] = to_string(best_unroll_h); + log("Best UNROLL_H is %u. Default UNROLL_H is %u.\n", best_unroll_h, AMDGPU && defaultShape->height >= 1024 ? 0 : 1); + configsUpdate(current_cost, best_cost, 0.003, "UNROLL_H", best_unroll_h, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["UNROLL_H"] = to_string(best_unroll_h); } // Find best ZEROHACK_W setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_zerohack_w = 0; + u32 current_zerohack_w = args->value("ZEROHACK_W", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 zerohack_w : {0, 1}) { - shared.args->flags["ZEROHACK_W"] = to_string(zerohack_w); + args->flags["ZEROHACK_W"] = to_string(zerohack_w); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using ZEROHACK_W=%u is %6.1f\n", fft.spec().c_str(), zerohack_w, cost); + if (zerohack_w == current_zerohack_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_w = zerohack_w; } } log("Best ZEROHACK_W is %u. Default ZEROHACK_W is 1.\n", best_zerohack_w); - shared.args->flags["ZEROHACK_W"] = to_string(best_zerohack_w); + configsUpdate(current_cost, best_cost, 0.003, "ZEROHACK_W", best_zerohack_w, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["ZEROHACK_W"] = to_string(best_zerohack_w); } // Find best ZEROHACK_H setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_zerohack_h = 0; + u32 current_zerohack_h = args->value("ZEROHACK_H", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 zerohack_h : {0, 1}) { - shared.args->flags["ZEROHACK_H"] = to_string(zerohack_h); + args->flags["ZEROHACK_H"] = to_string(zerohack_h); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using ZEROHACK_H=%u is %6.1f\n", fft.spec().c_str(), zerohack_h, cost); + if (zerohack_h == current_zerohack_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_h = zerohack_h; } } log("Best ZEROHACK_H is %u. Default ZEROHACK_H is 1.\n", best_zerohack_h); - shared.args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); + configsUpdate(current_cost, best_cost, 0.003, "ZEROHACK_H", best_zerohack_h, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); } // Find best MIDDLE_IN_LDS_TRANSPOSE setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_middle_in_lds_transpose = 0; + u32 current_middle_in_lds_transpose = args->value("MIDDLE_IN_LDS_TRANSPOSE", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 middle_in_lds_transpose : {0, 1}) { - shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); + if (middle_in_lds_transpose == current_middle_in_lds_transpose) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } } log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); - shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_IN_LDS_TRANSPOSE", best_middle_in_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); } // Find best MIDDLE_OUT_LDS_TRANSPOSE setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_middle_out_lds_transpose = 0; + u32 current_middle_out_lds_transpose = args->value("MIDDLE_OUT_LDS_TRANSPOSE", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 middle_out_lds_transpose : {0, 1}) { - shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); + if (middle_out_lds_transpose == current_middle_out_lds_transpose) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } } log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); - shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_OUT_LDS_TRANSPOSE", best_middle_out_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); } // Find best BIGLIT setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + if (time_FFTs) { + FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_biglit = 0; + u32 current_biglit = args->value("BIGLIT", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 biglit : {0, 1}) { - shared.args->flags["BIGLIT"] = to_string(biglit); + args->flags["BIGLIT"] = to_string(biglit); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); log("Time for %12s using BIGLIT=%u is %6.1f\n", fft.spec().c_str(), biglit, cost); + if (biglit == current_biglit) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_biglit = biglit; } } log("Best BIGLIT is %u. Default BIGLIT is 1. The BIGLIT=0 option will probably be deprecated.\n", best_biglit); - shared.args->flags["BIGLIT"] = to_string(best_biglit); + configsUpdate(current_cost, best_cost, 0.003, "BIGLIT", best_biglit, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["BIGLIT"] = to_string(best_biglit); } - //GW: Time some IN/OUT_WG/SIZEX combos? + // Output new settings to config.txt + File config = File::openAppend("config.txt"); + if (newConfigKeyVals.size()) { + config.write("\n# New settings based on a -tune run.\n"); + for (u32 i = 0; i < newConfigKeyVals.size(); ++i) { + config.write(i == 0 ? " -use " : ","); + config.printf("%s=%u", newConfigKeyVals[i].first.c_str(), newConfigKeyVals[i].second); + } + config.write("\n"); + } + if (suggestedConfigKeyVals.size()) { + config.write("\n# These settings were slightly faster in a -tune run.\n"); + config.write("\n# It is suggested that each setting be timed over a longer duration to see if the setting really is faster.\n"); + for (u32 i = 0; i < suggestedConfigKeyVals.size(); ++i) { + config.write(i == 0 ? "# -use " : ","); + config.printf("%s=%u", suggestedConfigKeyVals[i].first.c_str(), suggestedConfigKeyVals[i].second); + } + config.write("\n"); + } + config.write("\n# Running two workers often gives better throughput."); + config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better.\n"); + config.write("# -workers 2 -use TAIL_KERNEL=3\n"); } // Flags that prune the amount of shapes and variants to time. @@ -623,13 +909,11 @@ void Tune::tune() { skip_some_WH_variants = 2; // should default be 1?? skip_1K_256 = 0; -//GW: Suggest tuning with TAIL_KERNELS=2 even if production runs use TAIL_KERNELS=3 - - // For each width, time the 001, 101, and 201 variants to find the fastest width variant. + // For each width, time the 001, 101, and 201 FP64 variants to find the fastest width variant. // In an ideal world we'd use the -time feature and look at the kCarryFused timing. Then we'd save this info in config.txt or tune.txt. map fastest_width_variants; - // For each height, time the 100, 101, and 102 variants to find the fastest height variant. + // For each height, time the 100, 101, and 102 FP64 variants to find the fastest height variant. // In an ideal world we'd use the -time feature and look at the tailSquare timing. Then we'd save this info in config.txt or tune.txt. map fastest_height_variants; @@ -638,12 +922,19 @@ skip_1K_256 = 0; // Loop through all possible FFT shapes for (const FFTShape& shape : shapes) { + // Skip some FFTs and NTTs + if (shape.fft_type == FFT64 && !time_FFTs) continue; + if (shape.fft_type != FFT64 && !time_NTTs) continue; + // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); // Loop through all possible variants for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { + // Only FP64 code supports variants + if (variant != LAST_VARIANT && !FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) continue; + // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. if (variant_W(variant) == 0) { if (!AMDGPU) continue; @@ -656,6 +947,14 @@ skip_1K_256 = 0; if (shape.height > 1024) continue; } + // Reject shapes that won't be used to test exponents in the user's desired range + { + FFTConfig fft{shape, variant, CARRY_AUTO}; + if (fft.maxExp() < min_exponent) continue; + if (fft.maxExp() > 2*max_exponent) continue; + if (shape.fft_type == FFT64 && fft.maxExp() > 1.2*max_exponent) continue; + } + // If only one shape was specified on the command line, time it. This lets the user time any shape, including non-favored ones. if (shapes.size() > 1) { @@ -667,7 +966,7 @@ skip_1K_256 = 0; // Skip variants where width or height are not using the fastest variant. // NOTE: We ought to offer a tune=option where we also test more accurate variants to extend the FFT's max exponent. - if (skip_some_WH_variants) { + if (skip_some_WH_variants && FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) { u32 fastest_width = 1; if (auto it = fastest_width_variants.find(shape.width); it != fastest_width_variants.end()) { fastest_width = it->second; @@ -689,7 +988,7 @@ skip_1K_256 = 0; FFTConfig{shape, variant, CARRY_32}.maxBpw() < FFTConfig{shape, variant_WMH (fastest_width, variant_M(variant), variant_H(variant)), CARRY_32}.maxBpw()) continue; } - if (skip_some_WH_variants) { + if (skip_some_WH_variants && FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) { u32 fastest_height = 1; if (auto it = fastest_height_variants.find(shape.height); it != fastest_height_variants.end()) { fastest_height = it->second; @@ -715,10 +1014,13 @@ skip_1K_256 = 0; //GW: If variant is specified on command line, time it (and only it)?? Or an option to only time one variant number?? - vector carryToTest{CARRY_32}; - // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. - if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { - carryToTest.push_back(CARRY_64); + vector carryToTest{CARRY_AUTO}; + if (shape.fft_type == FFT64) { + carryToTest[0] = CARRY_32; + // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. + if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { + carryToTest.push_back(CARRY_64); + } } for (auto carry : carryToTest) { @@ -729,10 +1031,11 @@ skip_1K_256 = 0; double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); bool isUseful = TuneEntry{cost, fft}.update(results); - log("%c %6.1f %12s %9u\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); + log("%c %6.1f %12s %9lu\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); } } } +//GW: write results more often (in case -tune run is aborted)? TuneEntry::writeTuneFile(results); } diff --git a/src/tune.h b/src/tune.h index de50bf52..a64f2100 100644 --- a/src/tune.h +++ b/src/tune.h @@ -22,8 +22,8 @@ class Tune { GpuCommon shared; Primes primes; - double maxBpw(FFTConfig fft); - double zForBpw(double bpw, FFTConfig fft, u32); + float maxBpw(FFTConfig fft); + float zForBpw(float bpw, FFTConfig fft, u32); public: Tune(Queue *q, GpuCommon shared) : q{q}, shared{shared} {} From 95778ecf5b88c0e6fc41cb318323a5970d5381c9 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 4 Oct 2025 02:36:20 +0000 Subject: [PATCH 036/115] Added help text for -tune options. --- src/Args.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 10201957..28814d06 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -194,8 +194,14 @@ named "config.txt" in the prpll run directory. -use DEBUG : enable asserts in OpenCL kernels (slow, developers) --tune : measures the speed of the FFTs specified in -fft to find the best FFT for each exponent. - +-tune : Looks for best settings to include in config.txt. Times many FFTs to find fastest one to test exponents -- written to tune.txt. + An -fft can be given on the command line to limit which FFTs are timed. + Options are not required. If present, the options are a comma separated list from below. + noconfig - Skip timings to find best config.txt settings + fp64 - Tune for settings that affect FP64 FFTs. Time FP64 FFTs for tune.txt. + ntt - Tune for settings that affect integer NTTs. Time integer NTTs for tune.txt. + minexp= - Time FFTs to find the best one for exponents greater than . + maxexp= - Time FFTs to find the best one for exponents less than . -device : select the GPU at position N in the list of devices -uid : select the GPU with the given UID (on ROCm/AMDGPU, Linux) -pci : select the GPU with the given PCI BDF, e.g. "0c:00.0" @@ -286,7 +292,7 @@ void Args::parse(const string& line) { log(" FFT | BPW | Max exp (M)\n"); for (const FFTShape& shape : FFTShape::multiSpec(s)) { for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { - if (variant != LAST_VARIANT && shape.fft_type != FFT64) continue; + if (variant != LAST_VARIANT && shape.fft_type != FFT64) continue; FFTConfig fft{shape, variant, CARRY_AUTO}; log("%12s | %.2f | %5.1f\n", fft.spec().c_str(), fft.maxBpw(), fft.maxExp() / 1'000'000.0); } From a13998ef4998c794a833e98a37ef6c63cb028b60 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 4 Oct 2025 23:42:46 +0000 Subject: [PATCH 037/115] Fixed bug in conversion from 64-bit CPU words to 32-bit GPU words --- src/Gpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index e2ce780b..7793715c 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1037,7 +1037,7 @@ void Gpu::writeWords(Buffer& buf, vector &words) { vector GPUdata; GPUdata.resize(words.size() / 2); for (u32 i = 0; i < words.size(); i += 2) { - GPUdata[i/2] = ((i64) words[i+1] << 32) | (i32) words[i]; + GPUdata[i/2] = ((i64) words[i+1] << 32) | (u32) words[i]; } buf.write(std::move(GPUdata)); } From f57a0badd9d03280ba9d6b2f16264baa95c38a1d Mon Sep 17 00:00:00 2001 From: george Date: Sun, 5 Oct 2025 04:35:07 +0000 Subject: [PATCH 038/115] Only tune variant 202 for FP32 (may change that later). --- src/tune.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index 774cccde..aef31bda 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -933,7 +933,7 @@ skip_1K_256 = 0; for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { // Only FP64 code supports variants - if (variant != LAST_VARIANT && !FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) continue; + if (variant != 202 && !FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) continue; // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. if (variant_W(variant) == 0) { From 1edd3e1ff1f1fc1b44f991572663c386b7b170d2 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 5 Oct 2025 05:50:56 +0000 Subject: [PATCH 039/115] Have LL tests obey the -log command line argument --- src/Gpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 7793715c..f6eec979 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1875,7 +1875,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { log("Stopping, please wait..\n"); } - bool doLog = (k % 10'000 == 0) || doStop; + bool doLog = (k % args.logStep == 0) || doStop; bool leadOut = doLog || useLongCarry; squareLL(bufData, leadIn, leadOut); From 167a804eedc943c6f43c707996a0b7553882ff99 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 5 Oct 2025 14:47:47 +0000 Subject: [PATCH 040/115] Added -log 1000000 to -tune output --- src/tune.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index aef31bda..094535ef 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -874,25 +874,27 @@ void Tune::tune() { // Output new settings to config.txt File config = File::openAppend("config.txt"); if (newConfigKeyVals.size()) { - config.write("\n# New settings based on a -tune run.\n"); + config.write("\n# New settings based on a -tune run."); for (u32 i = 0; i < newConfigKeyVals.size(); ++i) { - config.write(i == 0 ? " -use " : ","); + config.write(i == 0 ? "\n -use " : ","); config.printf("%s=%u", newConfigKeyVals[i].first.c_str(), newConfigKeyVals[i].second); } config.write("\n"); } if (suggestedConfigKeyVals.size()) { - config.write("\n# These settings were slightly faster in a -tune run.\n"); - config.write("\n# It is suggested that each setting be timed over a longer duration to see if the setting really is faster.\n"); + config.write("\n# These settings were slightly faster in a -tune run."); + config.write("\n# It is suggested that each setting be timed over a longer duration to see if the setting really is faster."); for (u32 i = 0; i < suggestedConfigKeyVals.size(); ++i) { - config.write(i == 0 ? "# -use " : ","); + config.write(i == 0 ? "\n# -use " : ","); config.printf("%s=%u", suggestedConfigKeyVals[i].first.c_str(), suggestedConfigKeyVals[i].second); } config.write("\n"); } + config.write("\n# Less frequent save file creation improves throughput."); + config.write("\n -log 1000000\n"); config.write("\n# Running two workers often gives better throughput."); - config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better.\n"); - config.write("# -workers 2 -use TAIL_KERNEL=3\n"); + config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better."); + config.write("\n# -workers 2 -use TAIL_KERNEL=3\n"); } // Flags that prune the amount of shapes and variants to time. From 5a5748e8b48053ec418abea50b1163d62e209602 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 6 Oct 2025 01:12:07 +0000 Subject: [PATCH 041/115] Fixed typo in -workers 2 suggestion --- src/tune.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index 094535ef..7c97d62f 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -894,7 +894,7 @@ void Tune::tune() { config.write("\n -log 1000000\n"); config.write("\n# Running two workers often gives better throughput."); config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better."); - config.write("\n# -workers 2 -use TAIL_KERNEL=3\n"); + config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); } // Flags that prune the amount of shapes and variants to time. From 95cbce3faa7c70c060496c6b1b2d03f3730800fc Mon Sep 17 00:00:00 2001 From: george Date: Tue, 7 Oct 2025 04:01:22 +0000 Subject: [PATCH 042/115] Fixed bug in GF61 csqq that I don't think affected any existing code paths --- src/cl/math.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 18951301..0d26d3c1 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -731,7 +731,7 @@ GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } // Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). GF61 OVERLOAD csqq(GF61 a, const u32 m61_count) { - if (m61_count > 4) a = modM61(a); + if (m61_count > 4) return csqq(modM61(a), 2); Z61 re = weakMul(a.x + a.y, a.x + neg(a.y, m61_count), 2 * m61_count - 1, 2 * m61_count); Z61 im = weakMul(a.x + a.x, a.y, 2 * m61_count - 1, m61_count); return U2(re, im); From c5d2ce4a7701030637b90c5f4f18a0fe58db8f6d Mon Sep 17 00:00:00 2001 From: george Date: Tue, 7 Oct 2025 08:55:27 +0000 Subject: [PATCH 043/115] Improved FP32+GF61 carry propagation. Minor tweaks to all FP carry propagations. --- src/cl/carryutil.cl | 61 +++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index bbe291f2..f8bd56c3 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -165,7 +165,7 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r double d = fma(u, invWeight, RNDVALCarry); // Optionally calculate roundoff error - float roundoff = fabs((float) fma(u, -invWeight, d - RNDVALCarry)); + float roundoff = fabs((float) fma(u, invWeight, RNDVALCarry - d)); *maxROE = max(*maxROE, roundoff); // Convert to long (for CARRY32 case we don't need to strip off the RNDVAL bits) @@ -207,7 +207,7 @@ i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_r float d = fma(u, invWeight, RNDVALCarry); // Optionally calculate roundoff error - float roundoff = fabs(fma(u, -invWeight, d - RNDVALCarry)); + float roundoff = fabs(fma(u, invWeight, RNDVALCarry - d)); *maxROE = max(*maxROE, roundoff); // Convert to int @@ -300,16 +300,14 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, u32 n31 = get_Z31(u31); // The final result must be n31 mod M31. Use FP64 data to calculate this value. - u = u * invWeight - (double) n31; // This should be close to a multiple of M31 - u *= 4.656612875245796924105750827168e-10; // Divide by M31. Could divide by 2^31 (0.0000000004656612873077392578125) be accurate enough? //GWBUG - check the generated code! Use 1/M31??? + u = fma(u, invWeight, - (double) n31); // This should be close to a multiple of M31 + double uInt = fma(u, 4.656612875245796924105750827168e-10, RNDVAL); // Divide by M31 and round to int + i64 n64 = RNDVALdoubleToLong(uInt); - i64 n64 = RNDVALdoubleToLong(u + RNDVAL); - - i128 v = ((i128) n64 << 31) - n64; // n64 * M31 - v += n31; + i128 v = (((i128) n64 << 31) | n31) - n64; // n64 * M31 + n31 // Optionally calculate roundoff error - float roundoff = (float) fabs(u - (double) n64); + float roundoff = (float) fabs(fma(u, 4.656612875245796924105750827168e-10, RNDVAL - uInt)); *maxROE = max(*maxROE, roundoff); // Mul by 3 and add carry @@ -337,16 +335,14 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, u32 n31 = get_Z31(u31); // The final result must be n31 mod M31. Use FP32 data to calculate this value. - uF2 = uF2 * F2_invWeight - (float) n31; // This should be close to a multiple of M31 - uF2 *= 0.0000000004656612873077392578125f; // Divide by 2^31 //GWBUG - check the generated code! - - i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); + uF2 = fma(uF2, F2_invWeight, - (float) n31); // This should be close to a multiple of M31 + float uF2int = fma(uF2, 0.0000000004656612873077392578125f, RNDVAL); // Divide by 2^31 + i32 nF2 = RNDVALfloatToInt(uF2int); - i64 v = ((i64) nF2 << 31) - nF2; // nF2 * M31 - v += n31; + i64 v = (((i64) nF2 << 31) | n31) - nF2; // nF2 * M31 + n31 // Optionally calculate roundoff error - float roundoff = fabs(uF2 - nF2); + float roundoff = fabs(fma(uF2, 0.0000000004656612873077392578125f, RNDVAL - uF2int)); *maxROE = max(*maxROE, roundoff); // Mul by 3 and add carry @@ -371,39 +367,16 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, u61 = shr(u61, m61_invWeight); u64 n61 = get_Z61(u61); -#if 0 -BUG - need more than 64 bit integers - - // The final result must be n61 mod M61. Use FP32 data to calculate this value. - uF2 = uF2 * F2_invWeight - (float) n61; // This should be close to a multiple of M61 - uF2 *= 4.3368086899420177360298112034798e-19f; // Divide by 2^61 //GWBUG - check the generated code! - - i32 nF2 = lowBits(as_int(uF2 + RNDVAL), 22); - - i64 v = ((i64) nF2 << 61) - nF2; // nF2 * M61 - v += n61; - - // Optionally calculate roundoff error - float roundoff = fabs(uF2 - (float) nF2); - *maxROE = max(*maxROE, roundoff); -#else - // The final result must be n61 mod M61. Use FP32 data to calculate this value. -#undef RNDVAL //GWBUG - why are we using doubles? -#define RNDVAL (3.0 * (1l << 51)) - double uuF2 = (double) uF2 * (double) F2_invWeight - (double) n61; // This should be close to a multiple of M61 - uuF2 = uuF2 * 4.3368086899420177360298112034798e-19; // Divide by 2^61 //GWBUG - check the generated code! -volatile double xxF2 = uuF2 + RNDVAL; // Divide by 2^61 //GWBUG - check the generated code! - xxF2 -= RNDVAL; - i32 nF2 = (int) xxF2; + uF2 = fma(uF2, F2_invWeight, - (float) n61); // This should be close to a multiple of M61 + float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by 2^61 and round to int + i32 nF2 = RNDVALfloatToInt(uF2int); - i128 v = ((i128) nF2 << 61) - nF2; // nF2 * M61 - v += n61; + i128 v = (((i128) nF2 << 61) | n61) - nF2; // nF2 * M61 + n61 // Optionally calculate roundoff error - float roundoff = (float) fabs(uuF2 - (double) nF2); + float roundoff = fabs(fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL - uF2int)); *maxROE = max(*maxROE, roundoff); -#endif // Mul by 3 and add carry #if MUL3 From 15e4737c02c7632fdc870ed6766114f063d6d38e Mon Sep 17 00:00:00 2001 From: george Date: Tue, 7 Oct 2025 18:25:52 +0000 Subject: [PATCH 044/115] Remove extraneous comma after ROEavg output. Dropped logging ROE info if ROEmax < 0.005. --- src/Gpu.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index f6eec979..fea54b42 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1263,8 +1263,12 @@ void Gpu::doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u3 auto [roeSq, roeMul] = readROE(); double z = roeSq.z(); zAvg.update(z, roeSq.N); - log("%sZ=%.0f (avg %.1f), ROEmax=%.3f, ROEavg=%.3f, %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), - z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + if (roeSq.max > 0.005) + log("%sZ=%.0f (avg %.1f), ROEmax=%.3f, ROEavg=%.3f. %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), + z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + else + log("%sZ=%.0f (avg %.1f) %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), + z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); if (roeSq.N > 2 && z < 20) { log("Danger ROE! Z=%.1f is too small, increase precision or FFT size!\n", z); From ee46f6dee4df352c7dd436451c738ea9113bc929 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 7 Oct 2025 21:23:40 +0000 Subject: [PATCH 045/115] Save one and instruction in CUDA compile of M31 + M61 carry propagation --- src/cl/carryutil.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index f8bd56c3..a4ba1c5a 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -412,7 +412,7 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 i128 v = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 // Convert to balanced representation by subtracting M61*M31 - if ((v >> 64) & 0xF8000000) v = v - (i128) M31 * (i128) M61; + if (((u32)(v >> 64)) & 0xF8000000) v = v - (i128) M31 * (i128) M61; // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. u32 roundoff = (u32) abs((i32)(v >> 64)); From 90ba20d8687ae15af7c4bb1f7ff558680f6cde58 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 8 Oct 2025 08:06:08 +0000 Subject: [PATCH 046/115] Don't output -log and -workers tune suggestions if they are already set. --- src/tune.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index 7c97d62f..30be06da 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -890,11 +890,15 @@ void Tune::tune() { } config.write("\n"); } - config.write("\n# Less frequent save file creation improves throughput."); - config.write("\n -log 1000000\n"); - config.write("\n# Running two workers often gives better throughput."); - config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better."); - config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); + if (args->logStep < 100000) { + config.write("\n# Less frequent save file creation improves throughput."); + config.write("\n -log 1000000\n"); + } + if (args->workers < 2) { + config.write("\n# Running two workers often gives better throughput."); + config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better."); + config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); + } } // Flags that prune the amount of shapes and variants to time. From 6ea19b04d8b89dda816d456a573581efb6bad4b6 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 9 Oct 2025 00:48:31 +0000 Subject: [PATCH 047/115] For completeness, added tthe hybrid FP32+GF31+GF61 FFT. It may not be very useful as carryFused width=512 is under severe register pressure. --- src/FFTConfig.cpp | 22 ++-- src/FFTConfig.h | 2 +- src/Gpu.cpp | 28 ++-- src/cl/base.cl | 2 +- src/cl/carry.cl | 102 ++++++++++++++- src/cl/carryfused.cl | 297 ++++++++++++++++++++++++++++++++++++++++++- src/cl/carryinc.cl | 37 +++++- src/cl/carryutil.cl | 100 +++++++++++++-- src/cl/fftp.cl | 101 ++++++++++++++- src/fftbpw.h | 91 +++++++------ 10 files changed, 702 insertions(+), 80 deletions(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 09c945ae..09b07562 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -83,7 +83,7 @@ vector FFTShape::multiSpec(const string& iniSpec) { vector FFTShape::allShapes(u32 sizeFrom, u32 sizeTo) { vector configs; - for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61}) { + for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61, FFT323161}) { for (u32 width : {256, 512, 1024, 4096}) { for (u32 height : {256, 512, 1024}) { if (width == 256 && height == 1024) { continue; } // Skip because we prefer width >= height @@ -235,18 +235,16 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : assert(variant_M(variant) < N_VARIANT_M); assert(variant_H(variant) < N_VARIANT_H); - if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0; - else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1; - else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1; - else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1; - else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0; - else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0; - else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0; - else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0; + if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 4; + else if (shape.fft_type == FFT323161) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 8; + else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; else throw "FFT type"; - - if ((FFT_FP64 && NTT_GF31) || (NTT_GF31 && NTT_GF61) || (FFT_FP32 && NTT_GF61)) WordSize = 8; - else WordSize = 4; } string FFTConfig::spec() const { diff --git a/src/FFTConfig.h b/src/FFTConfig.h index b737d20b..ece9fcec 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -21,7 +21,7 @@ string numberK(u64 n); using KeyVal = std::pair; -enum FFT_TYPES {FFT64=0, FFT3161=1, FFT3261=2, FFT61=3, FFT3231=50, FFT6431=51, FFT31=52, FFT32=53}; +enum FFT_TYPES {FFT64=0, FFT3161=1, FFT3261=2, FFT61=3, FFT323161=4, FFT3231=50, FFT6431=51, FFT31=52, FFT32=53}; class FFTShape { public: diff --git a/src/Gpu.cpp b/src/Gpu.cpp index fea54b42..d18dbaae 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -140,7 +140,7 @@ Weights genWeights(FFTConfig fft, u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { memcpy((double *) weightsIF.data(), weightsIF32.data(), weightsIF32.size() * sizeof(float)); } - if (fft.FFT_FP64 || fft.FFT_FP64) { + if (fft.FFT_FP64 || fft.FFT_FP32) { for (u32 line = 0; line < H; ++line) { for (u32 thread = 0; thread < groupWidth; ) { std::bitset<32> b; @@ -342,44 +342,56 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< // The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". if (fft.FFT_FP64 && fft.NTT_GF31) { // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. - defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } + else if (fft.FFT_FP32 && fft.NTT_GF31 && fft.NTT_GF61) { + // GF31 and GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. + u32 sz1, sz2, sz3, sz4; + defines += toDefine("DISTGF31", sz1 = FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", sz2 = SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", sz3 = MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", sz4 = SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTGF61", sz1 + GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", sz2 + SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", sz3 + MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", sz4 + SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } else if (fft.FFT_FP32 && fft.NTT_GF31) { // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } else if (fft.FFT_FP32 && fft.NTT_GF61) { // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } else if (fft.NTT_GF31 && fft.NTT_GF61) { - defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTGF31", 0); defines += toDefine("DISTWTRIGGF31", 0); defines += toDefine("DISTMTRIGGF31", 0); defines += toDefine("DISTHTRIGGF31", 0); // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. - defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } else if (fft.NTT_GF31) { - defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTGF31", 0); defines += toDefine("DISTWTRIGGF31", 0); defines += toDefine("DISTMTRIGGF31", 0); defines += toDefine("DISTHTRIGGF31", 0); } else if (fft.NTT_GF61) { - defines += toDefine("DISTGF61", 0); + defines += toDefine("DISTGF61", 0); defines += toDefine("DISTWTRIGGF61", 0); defines += toDefine("DISTMTRIGGF61", 0); defines += toDefine("DISTHTRIGGF61", 0); diff --git a/src/cl/base.cl b/src/cl/base.cl index 24986309..7acd0fe2 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -184,7 +184,7 @@ typedef ulong2 GF61; // A complex value using two Z61s. For a GF(M61^2) // Typedefs for "combo" FFT/NTTs (multiple NTT primes or hybrid FFT/NTT). #define COMBO_FFT (FFT_FP64 + FFT_FP32 + NTT_GF31 + NTT_GF61 > 1) // Sanity check for supported FFT/NTT -#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31) +#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31) | (FFT_FP32 & NTT_GF31 & NTT_GF61 & !FFT_FP64) #elif !COMBO_FFT | (FFT_FP32 & NTT_GF31 & !FFT_FP64 & !NTT_GF61) #else error - unsupported FFT/NTT combination diff --git a/src/cl/carry.cl b/src/cl/carry.cl index dfb32fb2..26bc6e91 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -320,7 +320,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 +#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); @@ -399,7 +399,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF61 +#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); @@ -478,7 +478,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF31 & NTT_GF61 +#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { u32 g = get_group_id(0); @@ -566,6 +566,102 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u } +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 & NTT_GF61 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = (m31_weight_shift + log2_NWORDS + 1) % 31; + m61_weight_shift = (m61_weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF31 and GF61 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in31[p]), SWAP_XY(in61[p]), w1, w2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; +// GWBUG - derive m61 weight shifts from m31 counter (or vice versa) sort of easily done from difference in the two weight shifts (no need to add frac_bits twice) + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + #else error - missing Carry kernel implementation #endif diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index 9fcb6da6..cdacc841 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -1106,7 +1106,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 +#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1352,7 +1352,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF61 +#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1598,7 +1598,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF31 & NTT_GF61 +#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1853,6 +1853,297 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( } +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 & NTT_GF61 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds61[WIDTH / 4]; +#else + local GF61 lds61[WIDTH / 2]; +#endif + local F2 *ldsF2 = (local F2 *) lds61; + local GF31 *lds31 = (local GF31 *) lds61; + + F2 uF2[NW]; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in31, u31, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * H * 2 - 1) * m31_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + u64 m31_starting_combo_counter = m31_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * H * 2 - 1) * m61_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + u64 m61_starting_combo_counter = m61_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = m31_weight_shift + log2_NWORDS + 1; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_weight_shift = m61_weight_shift + log2_NWORDS + 1; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF31 and GF61 weight shift + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u31[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation + m61_combo_counter = m61_starting_combo_counter; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), m31_weight_shift0), shl(make_Z31(wu[i].y), m31_weight_shift1)); + u61[i] = U2(shl(make_Z61(wu[i].x), m61_weight_shift0), shl(make_Z61(wu[i].y), m61_weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + #else error - missing CarryFused kernel implementation #endif diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 67d75565..a8d36307 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -165,7 +165,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeigh /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 +#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -197,7 +197,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWei /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF61 +#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -229,7 +229,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWei /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF31 & NTT_GF61 +#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -256,6 +256,37 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, return (Word2) (a, b); } +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 & NTT_GF61 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + iCARRY midCarry; + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + iCARRY midCarry; + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + #else error - missing weightAndCarryPair implementation #endif diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index a4ba1c5a..97bb48df 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -323,7 +323,7 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 +#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 #define SLOPPY_MAXBPW 154 // Based on 138M expo in 8M FFT = 16.45 BPW @@ -336,13 +336,13 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, // The final result must be n31 mod M31. Use FP32 data to calculate this value. uF2 = fma(uF2, F2_invWeight, - (float) n31); // This should be close to a multiple of M31 - float uF2int = fma(uF2, 0.0000000004656612873077392578125f, RNDVAL); // Divide by 2^31 + float uF2int = fma(uF2, 4.656612875245796924105750827168e-10f, RNDVAL); // Divide by M31 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); i64 v = (((i64) nF2 << 31) | n31) - nF2; // nF2 * M31 + n31 // Optionally calculate roundoff error - float roundoff = fabs(fma(uF2, 0.0000000004656612873077392578125f, RNDVAL - uF2int)); + float roundoff = fabs(fma(uF2, 4.656612875245796924105750827168e-10f, RNDVAL - uF2int)); *maxROE = max(*maxROE, roundoff); // Mul by 3 and add carry @@ -356,7 +356,7 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF61 +#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 #define SLOPPY_MAXBPW 309 // Based on 134M expo in 4M FFT = 31.95 BPW @@ -391,7 +391,7 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF31 & NTT_GF61 +#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 #define SLOPPY_MAXBPW 383 // Based on 165M expo in 4M FFT = 39.34 BPW @@ -423,7 +423,7 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 v = v * 3; #endif v = v + inCarry; - i96 value = make_i96((u64) (v >> 32), (u32) v); + i96 value = make_i96(v); #else @@ -447,6 +447,50 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 return value; } +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 & NTT_GF61 + +#define SLOPPY_MAXBPW 461 // Based on 198M expo in 4M FFT = 47.20 BPW + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { + + // Apply inverse weights + u31 = shr(u31, m31_invWeight); + u61 = shr(u61, m61_invWeight); + + // Use chinese remainder theorem to create a 92-bit result. Loosely copied from Yves Gallot's mersenne2 program. + u32 n31 = get_Z31(u31); + u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 + u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) + u64 n61 = get_Z61(u61); + i128 n3161 = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 + +//GW - computing (float)n3161 may be complicated and we don't need all that precision. . Would forming a smaller value (like just n61) to float be faster with the *M31 taken into account with the constant? + + // The final result must be n3161 mod M31*M61. Use FP32 data to calculate this value. + uF2 = fma(uF2, F2_invWeight, - (float) n3161); // This should be close to a multiple of M31*M61 + float uF2int = fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL); // Divide by M31*M61 and round to int + i32 nF2 = RNDVALfloatToInt(uF2int); + + i64 nF2m31 = ((i64) nF2 << 31) - nF2; // nF2 * M31 + i128 v = ((i128) nF2m31 << 61) - nF2m31 + n3161; // nF2m31 * M61 + n3161 + + // Optionally calculate roundoff error + float roundoff = fabs(fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL - uF2int)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + v = v * 3; +#endif + v = v + inCarry; + return v; +} + #else error - missing weightAndCarryOne implementation #endif @@ -456,8 +500,14 @@ error - missing weightAndCarryOne implementation /* Split a value + carryIn into a big-or-little word and a carryOut */ /************************************************************************/ +Word OVERLOAD carryStep(i128 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + i64 w = lowBits((i64)x, nBits); + *outCarry = (u64)(x >> nBits) + (w < 0); + return w; +} + Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { - const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); //GWBUG - is this ever faster? Not on TitanV @@ -552,6 +602,19 @@ Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { /* CarryFinal will later turn this into a balanced signed value. */ /*****************************************************************/ +Word OVERLOAD carryStepUnsignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + u64 w = ulowFixedBits((u64)x, bigwordBits); + const i128 topbitmask = ~((i128)1 << (bigwordBits - 1)); +//GW Can we use unsigned shift (knowing the sign won't be lost due to truncating the result) -- this is really a 64-bit extract (or two 32-bit extrats) -- use elsewhere?) + *outCarry = (x & topbitmask) >> nBits; +//GW use this style else where, check for more fixed low bits + return w; +} + Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); @@ -560,12 +623,12 @@ Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { #if EXP / NWORDS >= 32 // nBits is 32 or more i64 xhi = i96_hi64(x) & ~((1ULL << (bigwordBits - 32)) - 1); *outCarry = xhi >> (nBits - 32); - return ulowBits(i96_lo64(x), bigwordBits); + return ulowFixedBits(i96_lo64(x), bigwordBits); #elif EXP / NWORDS == 31 // nBits = 31 or 32 *outCarry = i96_hi64(x) << (32 - nBits); return i96_lo32(x); // ulowBits(x, bigwordBits = 32); #else // nBits less than 32 - u32 w = ulowBits(i96_lo32(x), bigwordBits); + u32 w = ulowFixedBits(i96_lo32(x), bigwordBits); *outCarry = (i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) - w) >> nBits); return w; #endif @@ -595,6 +658,23 @@ Word OVERLOAD carryStepUnsignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { /* Also used on first word in carryFinal when not near max BPW. */ /**********************************************************************/ +Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { +#if EXP > NWORDS / 10 * SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + u64 xlo = (u64)x; + u64 xlo_topbit = xlo & (1ULL << (bigwordBits - 1)); + i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; + *outCarry = (x + xlo_topbit) >> nBits; + return w; +#endif +} + + Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { #if EXP > NWORDS / 10 * SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); @@ -619,7 +699,7 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { *outCarry = (i96_hi64(x) + (w < 0)) << (32 - nBits); return w; #else // nBits less than 32 //GWBUG - is there a faster version? Is this faster than plain old carryStep? - i32 w = lowBits(i96_lo32(x), bigwordBits); + i32 w = lowFixedBits(i96_lo32(x), bigwordBits); *outCarry = (((i96_hi64(x) << (32 - bigwordBits)) | (i96_lo32(x) >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); return w; #endif diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index eafebfe8..6a0f65fe 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -248,7 +248,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 +#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { @@ -318,7 +318,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF61 +#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { @@ -388,7 +388,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF31 & NTT_GF61 +#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { @@ -467,6 +467,101 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { } +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_FP32 & NTT_GF31 & NTT_GF61 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + local GF31 *lds31 = (local GF31 *) lds61; + F2 uF2[NW]; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m31_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m61_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF31 and GF61 weight shift + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + u32 m61_weight_shift1 = m61_weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), m31_weight_shift0), shl(make_Z31(in[p].y), m31_weight_shift1)); // Form a GF31 from each pair of input words + u61[i] = U2(shl(make_Z61(in[p].x), m61_weight_shift0), shl(make_Z61(in[p].y), m61_weight_shift1)); // Form a GF61 from each pair of input words + +// Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + if (m31_weight_shift > 31) m31_weight_shift -= 31; + m61_combo_counter += m61_combo_bigstep; + if (m61_weight_shift > 61) m61_weight_shift -= 61; + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); +} + + #else error - missing FFTp kernel implementation #endif diff --git a/src/fftbpw.h b/src/fftbpw.h index 928f45be..86392e3d 100644 --- a/src/fftbpw.h +++ b/src/fftbpw.h @@ -111,24 +111,24 @@ { "1:4K:16:512", {38.94, 38.94, 38.94, 38.94, 38.94, 38.94}}, { "1:4K:16:1K", {38.84, 38.84, 38.84, 38.84, 38.84, 38.84}}, // FFT3261 -{ "2:256:2:256", {32.05, 32.35, 32.35, 32.35, 32.35, 32.35}}, -{ "2:256:4:256", {31.95, 32.25, 32.25, 32.25, 32.25, 32.25}}, -{ "2:256:8:256", {31.85, 32.15, 32.15, 32.15, 32.15, 32.15}}, -{ "2:512:4:256", {31.85, 32.15, 32.15, 32.15, 32.15, 32.15}}, -{"2:256:16:256", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, -{ "2:512:8:256", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, -{ "2:512:4:512", {31.75, 32.05, 32.05, 32.05, 32.05, 32.05}}, -{ "2:1K:8:256", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, -{"2:512:16:256", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, -{ "2:512:8:512", {31.65, 31.95, 31.95, 31.95, 31.95, 31.95}}, -{ "2:1K:16:256", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, -{ "2:1K:8:512", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, -{"2:512:16:512", {31.55, 31.85, 31.85, 31.85, 31.85, 31.85}}, -{ "2:1K:16:512", {31.45, 31.75, 31.75, 31.75, 31.75, 31.75}}, -{ "2:1K:8:1K", {31.45, 31.75, 31.75, 31.75, 31.75, 31.75}}, -{ "2:1K:16:1K", {31.35, 31.65, 31.65, 31.65, 31.65, 31.65}}, -{ "2:4K:16:512", {31.25, 31.55, 31.55, 31.55, 31.55, 31.55}}, -{ "2:4K:16:1K", {31.15, 31.45, 31.45, 31.45, 31.45, 31.45}}, +{ "2:256:2:256", {32.05, 32.05, 32.05, 32.05, 32.05, 32.05}}, +{ "2:256:4:256", {31.95, 31.95, 31.95, 31.95, 31.95, 31.95}}, +{ "2:256:8:256", {31.85, 31.85, 31.85, 31.85, 31.85, 31.85}}, +{ "2:512:4:256", {31.85, 31.85, 31.85, 31.85, 31.85, 31.85}}, +{"2:256:16:256", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, +{ "2:512:8:256", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, +{ "2:512:4:512", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, +{ "2:1K:8:256", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, +{"2:512:16:256", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, +{ "2:512:8:512", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, +{ "2:1K:16:256", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, +{ "2:1K:8:512", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, +{"2:512:16:512", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, +{ "2:1K:16:512", {31.45, 31.45, 31.45, 31.45, 31.45, 31.45}}, +{ "2:1K:8:1K", {31.45, 31.45, 31.45, 31.45, 31.45, 31.45}}, +{ "2:1K:16:1K", {31.35, 31.35, 31.35, 31.35, 31.35, 31.35}}, +{ "2:4K:16:512", {31.25, 31.25, 31.25, 31.25, 31.25, 31.25}}, +{ "2:4K:16:1K", {31.15, 31.15, 31.15, 31.15, 31.15, 31.15}}, // FFT61 { "3:256:2:256", {24.20, 24.20, 24.20, 24.20, 24.20, 24.20}}, { "3:256:4:256", {24.10, 24.10, 24.10, 24.10, 24.10, 24.10}}, @@ -148,25 +148,44 @@ { "3:1K:16:1K", {23.50, 23.50, 23.50, 23.50, 23.50, 23.50}}, { "3:4K:16:512", {23.40, 23.40, 23.40, 23.40, 23.40, 23.40}}, { "3:4K:16:1K", {23.30, 23.30, 23.30, 23.30, 23.30, 23.30}}, +// FFT323161 +{ "4:256:2:256", {47.65, 47.65, 47.65, 47.65, 47.65, 47.65}}, +{ "4:256:4:256", {47.55, 47.55, 47.55, 47.55, 47.55, 47.55}}, +{ "4:256:8:256", {47.45, 47.45, 47.45, 47.45, 47.45, 47.45}}, +{ "4:512:4:256", {47.45, 47.45, 47.45, 47.45, 47.45, 47.45}}, +{"4:256:16:256", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, +{ "4:512:8:256", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, +{ "4:512:4:512", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, +{ "4:1K:8:256", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, +{"4:512:16:256", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, +{ "4:512:8:512", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, +{ "4:1K:16:256", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, +{ "4:1K:8:512", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, +{"4:512:16:512", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, +{ "4:1K:16:512", {47.05, 47.05, 47.05, 47.05, 47.05, 47.05}}, +{ "4:1K:8:1K", {47.05, 47.05, 47.05, 47.05, 47.05, 47.05}}, +{ "4:1K:16:1K", {46.95, 46.95, 46.95, 46.95, 46.95, 46.95}}, +{ "4:4K:16:512", {46.85, 46.85, 46.85, 46.85, 46.85, 46.85}}, +{ "4:4K:16:1K", {46.75, 46.75, 46.75, 46.75, 46.75, 46.75}}, // FFT3231 -{ "50:256:2:256", {16.95, 34.26, 34.26, 34.26, 34.26, 34.26}}, -{ "50:256:4:256", {16.85, 34.16, 34.16, 34.16, 34.16, 34.16}}, -{ "50:256:8:256", {16.75, 34.06, 34.06, 34.06, 34.06, 34.06}}, -{ "50:512:4:256", {16.75, 34.06, 34.06, 34.06, 34.06, 34.06}}, -{"50:256:16:256", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "50:512:8:256", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "50:512:4:512", {16.65, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "50:1K:8:256", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{"50:512:16:256", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{ "50:512:8:512", {16.55, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{ "50:1K:16:256", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{ "50:1K:8:512", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{"50:512:16:512", {16.45, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{ "50:1K:16:512", {16.35, 33.66, 33.66, 33.66, 33.66, 33.66}}, -{ "50:1K:8:1K", {16.35, 33.66, 33.66, 33.66, 33.66, 33.66}}, -{ "50:1K:16:1K", {16.25, 33.56, 33.56, 33.56, 33.56, 33.56}}, -{ "50:4K:16:512", {16.15, 33.46, 33.46, 33.46, 33.46, 33.46}}, -{ "50:4K:16:1K", {16.05, 33.36, 33.36, 33.36, 33.36, 33.36}}, +{ "50:256:2:256", {16.95, 16.95, 16.95, 16.95, 16.95, 16.95}}, +{ "50:256:4:256", {16.85, 16.85, 16.85, 16.85, 16.85, 16.85}}, +{ "50:256:8:256", {16.75, 16.75, 16.75, 16.75, 16.75, 16.75}}, +{ "50:512:4:256", {16.75, 16.75, 16.75, 16.75, 16.75, 16.75}}, +{"50:256:16:256", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, +{ "50:512:8:256", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, +{ "50:512:4:512", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, +{ "50:1K:8:256", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, +{"50:512:16:256", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, +{ "50:512:8:512", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, +{ "50:1K:16:256", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, +{ "50:1K:8:512", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, +{"50:512:16:512", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, +{ "50:1K:16:512", {16.35, 16.35, 16.35, 16.35, 16.35, 16.35}}, +{ "50:1K:8:1K", {16.35, 16.35, 16.35, 16.35, 16.35, 16.35}}, +{ "50:1K:16:1K", {16.25, 16.25, 16.25, 16.25, 16.25, 16.25}}, +{ "50:4K:16:512", {16.15, 16.15, 16.15, 16.15, 16.15, 16.15}}, +{ "50:4K:16:1K", {16.05, 16.05, 16.05, 16.05, 16.05, 16.05}}, // FFT6431 { "51:256:2:256", {34.26, 34.26, 34.26, 34.26, 34.26, 34.26}}, { "51:256:4:256", {34.16, 34.16, 34.16, 34.16, 34.16, 34.16}}, From 90196940295a7595d997bbe4937e16dfae6b37a8 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 9 Oct 2025 04:07:13 +0000 Subject: [PATCH 048/115] Improved M31+M61 carry propagation by using get_balanced_Z61 Created 4 different implementations of i96 data type. Will select one after dabbling with add.cc PTX inline assembly. --- src/cl/carryfused.cl | 2 +- src/cl/carryutil.cl | 25 ++++++++------------ src/cl/math.cl | 54 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 18 deletions(-) diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index cdacc841..fe788df2 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -1737,7 +1737,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( m61_combo_counter = m61_starting_combo_counter; #if ROE - float fltRoundMax = (float) roundMax / (float) 0x0FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float - divide by M61*M31. + float fltRoundMax = (float) roundMax / (float) 0x1FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float - divide by M61. updateStats(bufROE, posROE, fltRoundMax); #elif STATS & (1 << MUL3) updateStats(bufROE, posROE, carryMax); diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 97bb48df..de315ecd 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -406,17 +406,18 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 u32 n31 = get_Z31(u31); u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) - u64 n61 = get_Z61(u61); - -#if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. - i128 v = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 - // Convert to balanced representation by subtracting M61*M31 - if (((u32)(v >> 64)) & 0xF8000000) v = v - (i128) M31 * (i128) M61; + // The resulting value will be get_Z61(u61) * M31 + n31 and if larger than ~M31*M61/2 return a negative value by subtracting M31 * M61. + // We can save a little work by determining if the result will be large using just u61 and returning (get_Z61(u61) - M61) * M31 + n31. + // This simplifies to get_balanced_Z61(u61) * M31 + n31. + i64 n61 = get_balanced_Z61(u61); - // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. - u32 roundoff = (u32) abs((i32)(v >> 64)); + // Optionally calculate roundoff error as proximity to M61/2. 28 bits of accuracy should be sufficient. + u32 roundoff = (u32) abs((i32)(n61 >> 32)); *maxROE = max(*maxROE, roundoff); + +#if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. + i128 v = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 // Mul by 3 and add carry #if MUL3 @@ -426,17 +427,9 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 i96 value = make_i96(v); #else - i96 value = make_i96(n61 >> 1, ((u32) n61 << 31) | n31); // (n61<<31) + n31 i96_sub(&value, n61); - // Convert to balanced representation by subtracting M61*M31 - if (i96_hi32(value) & 0xF8000000) i96_sub(&value, make_i96(0x0FFFFFFF, 0xDFFFFFFF, 0x80000001)); - - // Optionally calculate roundoff error as proximity to M61*M31/2. 27 bits of accuracy should be sufficient. - u32 roundoff = (u32) abs((i32) i96_hi32(value)); - *maxROE = max(*maxROE, roundoff); - // Mul by 3 and add carry #if MUL3 i96_mul(&value, 3); diff --git a/src/cl/math.cl b/src/cl/math.cl index 0d26d3c1..da571dc4 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -10,10 +10,13 @@ u32 lo32(u64 x) { return (u32) x; } u32 hi32(u64 x) { return (u32) (x >> 32); } // A primitive partial implementation of an i96/u96 integer type +// It seems that the clang PTX compiler struggles with switching between u32s and u64s +#if 0 // This was the first cut. Unions are messy as they may require recombining registers too frequently. typedef union { struct { u32 lo32; u32 mid32; u32 hi32; } a; struct { u64 lo64; u32 hi32; } c; } i96; +i96 OVERLOAD make_i96(u128 v) { i96 val; val.c.hi32 = v >> 64, val.c.lo64 = v; return val; } i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.a.hi32 = h, val.a.mid32 = m, val.a.lo32 = l; return val; } i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.a.hi32 = hi32(h), val.a.mid32 = lo32(h), val.a.lo32 = l; return val; } i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.c.hi32 = h, val.c.lo64 = l; return val; } @@ -26,6 +29,55 @@ u64 i96_lo64(i96 val) { return val.c.lo64; } u64 i96_hi64(i96 val) { return ((u64) val.a.hi32 << 32) + val.a.mid32; } u32 i96_lo32(i96 val) { return val.a.lo32; } u32 i96_mid32(i96 val) { return val.a.mid32; } +#elif 0 +// An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions +typedef struct { u32 lo32; u32 mid32; u32 hi32; } i96; +i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.hi32 = h, val.mid32 = m, val.lo32 = l; return val; } +i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.hi32 = hi32(h), val.mid32 = lo32(h), val.lo32 = l; return val; } +i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.hi32 = h, val.mid32 = hi32(l); val.lo32 = lo32(l); return val; } +void i96_add(i96 *val, i96 x) { val->lo32 += x.lo32; val->mid32 += x.mid32; val->hi32 += x.hi32 + (val->mid32 < x.mid32); u32 carry = (val->lo32 < x.lo32); val->mid32 += carry; val->hi32 += (val->mid32 < carry); } +void OVERLOAD i96_sub(i96 *val, i96 x) { i96 tmp = *val; val->lo32 -= x.lo32; val->mid32 -= x.mid32; val->hi32 -= x.hi32 + (val->mid32 > tmp.mid32); u32 carry = (val->lo32 > tmp.lo32); tmp = *val; val->mid32 -= carry; val->hi32 -= (val->mid32 > tmp.mid32); } +void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->lo32 * x; val->lo32 = (u32)t; t = (u64)val->mid32 * x + (t >> 32); val->mid32 = (u32)t; val->hi32 = val->hi32 * x + (u32)(t >> 32); } +u32 i96_hi32(i96 val) { return val.hi32; } +u32 i96_mid32(i96 val) { return val.mid32; } +u32 i96_lo32(i96 val) { return val.lo32; } +u64 i96_lo64(i96 val) { return ((u64) val.mid32 << 32) | val.lo32; } +u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | val.mid32; } +#elif 1 +// A u64 lo32, u32 hi32 implementation. This too would benefit from add.cc ASM instructions. +typedef struct { u64 lo64; u32 hi32; } i96; +i96 OVERLOAD make_i96(u128 v) { i96 val; val.hi32 = v >> 64, val.lo64 = v; return val; } +i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.hi32 = h, val.lo64 = ((u64) m << 32) | l; return val; } +i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.hi32 = hi32(h), val.lo64 = ((u64) lo32(h) << 32) | l; return val; } +i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.hi32 = h, val.lo64 = l; return val; } +u32 i96_hi32(i96 val) { return val.hi32; } +u32 i96_mid32(i96 val) { return hi32(val.lo64); } +u32 i96_lo32(i96 val) { return val.lo64; } +u64 i96_lo64(i96 val) { return val.lo64; } +u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | i96_mid32(val); } +void i96_add(i96 *val, i96 x) { val->lo64 += x.lo64; val->hi32 += x.hi32 + (val->lo64 < x.lo64); } +void OVERLOAD i96_sub(i96 *val, i96 x) { u64 tmp = val->lo64; val->lo64 -= x.lo64; val->hi32 -= x.hi32 + (val->lo64 > tmp); } +void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_mul(i96 *val, u32 x) { u64 t = i96_lo32(*val) * (u64)x; u32 lo32 = t; t = i96_mid32(*val) * (u64)x + (t >> 32); u32 mid32 = t; u32 hi32 = val->hi32 * x + (t >> 32); *val = make_i96(hi32, mid32, lo32); } +#elif 1 +// A u128 implementation. This will use more registers. +typedef struct { u128 x; } i96; +i96 OVERLOAD make_i96(u128 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.x = ((u128) h << 64) | ((u128) m << 32) | l; return val; } +i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.x = ((u128) h << 32) | l; return val; } +i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.x = ((u128) h << 64) | l; return val; } +u32 i96_hi32(i96 val) { return val.x >> 64; } +u32 i96_mid32(i96 val) { return val.x >> 32; } +u32 i96_lo32(i96 val) { return val.x; } +u64 i96_lo64(i96 val) { return val.x; } +u64 i96_hi64(i96 val) { return val.x >> 32; } +void i96_add(i96 *val, i96 v) { val->x += v.x; } +void OVERLOAD i96_sub(i96 *val, i96 v) { val->x -= v.x; } +void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_mul(i96 *val, u32 x) { val->x *= x; } +#endif + // The X2 family of macros and SWAP are #defines because OpenCL does not allow pass by reference. // With NTT support added, we need to turn these macros into overloaded routines. @@ -670,7 +722,7 @@ void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { X2_conjb_internal(a, #elif 1 // Faster version that keeps results in the range 0 .. M61+epsilon u64 OVERLOAD get_Z61(Z61 a) { Z61 m = a - M61; return (m & 0x8000000000000000ULL) ? a : m; } // Get value in range 0 to M61-1 -i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 +i64 OVERLOAD get_balanced_Z61(Z61 a) { return (a >= 0x1000000000000000ULL) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 // Internal routine to bring Z61 value into the range 0..M61+epsilon Z61 OVERLOAD modM61(Z61 a) { return (a & M61) + (a >> 61); } From 0eff38c7260ada3264790abaa3819fe7684fb680 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 9 Oct 2025 18:09:06 +0000 Subject: [PATCH 049/115] In M31*M61 carry propagation, construct 128-bit v value differently. Saves 2 PTX instructions (clang generated poor code). --- src/cl/carryutil.cl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index de315ecd..ba4354e0 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -417,7 +417,9 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 *maxROE = max(*maxROE, roundoff); #if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. - i128 v = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 + i64 vhi = n61 >> 33; + u64 vlo = (n61 << 31) | n31; + i128 v = (((i128)vhi << 64) | (i128)vlo) - n61; // n61 * M31 + n31 // Mul by 3 and add carry #if MUL3 From f29abc4136d6b2fdb23ffeb5b6fb045260b27455 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 9 Oct 2025 20:15:53 +0000 Subject: [PATCH 050/115] Finalized decision on using i128 data type to implement i96 math. Cleaned up carryutil code accordingly. --- src/cl/carryutil.cl | 50 +++++++++++++--------------- src/cl/math.cl | 80 ++++++++++++++++++--------------------------- 2 files changed, 54 insertions(+), 76 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index ba4354e0..e3cda20a 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -304,18 +304,21 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, double uInt = fma(u, 4.656612875245796924105750827168e-10, RNDVAL); // Divide by M31 and round to int i64 n64 = RNDVALdoubleToLong(uInt); - i128 v = (((i128) n64 << 31) | n31) - n64; // n64 * M31 + n31 - // Optionally calculate roundoff error float roundoff = (float) fabs(fma(u, 4.656612875245796924105750827168e-10, RNDVAL - uInt)); *maxROE = max(*maxROE, roundoff); + // Compute the value using i96 math + i64 vhi = n64 >> 33; + u64 vlo = ((u64)n64 << 31) | n31; + i96 value = make_i96(vhi, vlo); // (n64 << 31) + n31 + i96_sub(&value, make_i96(n64)); // n64 * M31 + n31 + // Mul by 3 and add carry #if MUL3 - v = v * 3; + i96_mul(&value, 3); #endif - v += inCarry; - i96 value = make_i96((u64) (v >> 32), (u32) v); + i96_add(&value, make_i96(inCarry)); return value; } @@ -372,18 +375,21 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by 2^61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); - i128 v = (((i128) nF2 << 61) | n61) - nF2; // nF2 * M61 + n61 - // Optionally calculate roundoff error float roundoff = fabs(fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL - uF2int)); *maxROE = max(*maxROE, roundoff); + // Compute the value using i96 math + i32 vhi = nF2 >> 3; + u64 vlo = ((u64)nF2 << 61) | n61; + i96 value = make_i96(vhi, vlo); // (nF2 << 61) + n61 + i96_sub(&value, make_i96(nF2)); // nF2 * M61 + n61 + // Mul by 3 and add carry #if MUL3 - v = v * 3; + i96_mul(&value, 3); #endif - v += inCarry; - i96 value = make_i96((u64) (v >> 32), (u32) v); + i96_add(&value, make_i96(inCarry)); return value; } @@ -415,30 +421,18 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 // Optionally calculate roundoff error as proximity to M61/2. 28 bits of accuracy should be sufficient. u32 roundoff = (u32) abs((i32)(n61 >> 32)); *maxROE = max(*maxROE, roundoff); - -#if 1 //GWBUG - is this better/as good as int96 code? TitanV seems at least as good. - i64 vhi = n61 >> 33; - u64 vlo = (n61 << 31) | n31; - i128 v = (((i128)vhi << 64) | (i128)vlo) - n61; // n61 * M31 + n31 - // Mul by 3 and add carry -#if MUL3 - v = v * 3; -#endif - v = v + inCarry; - i96 value = make_i96(v); - -#else - i96 value = make_i96(n61 >> 1, ((u32) n61 << 31) | n31); // (n61<<31) + n31 - i96_sub(&value, n61); + // Compute the value using i96 math + i64 vhi = n61 >> 33; + u64 vlo = ((u64)n61 << 31) | n31; + i96 value = make_i96(vhi, vlo); // (n61 << 31) + n31 + i96_sub(&value, make_i96(n61)); // n61 * M31 + n31 // Mul by 3 and add carry #if MUL3 i96_mul(&value, 3); #endif - i96_add(&value, make_i96((u32)(inCarry >> 63), (u64) inCarry)); -#endif - + i96_add(&value, make_i96(inCarry)); return value; } diff --git a/src/cl/math.cl b/src/cl/math.cl index da571dc4..ac1749a8 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -9,72 +9,56 @@ u32 lo32(u64 x) { return (u32) x; } u32 hi32(u64 x) { return (u32) (x >> 32); } -// A primitive partial implementation of an i96/u96 integer type -// It seems that the clang PTX compiler struggles with switching between u32s and u64s -#if 0 // This was the first cut. Unions are messy as they may require recombining registers too frequently. -typedef union { - struct { u32 lo32; u32 mid32; u32 hi32; } a; - struct { u64 lo64; u32 hi32; } c; -} i96; -i96 OVERLOAD make_i96(u128 v) { i96 val; val.c.hi32 = v >> 64, val.c.lo64 = v; return val; } -i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.a.hi32 = h, val.a.mid32 = m, val.a.lo32 = l; return val; } -i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.a.hi32 = hi32(h), val.a.mid32 = lo32(h), val.a.lo32 = l; return val; } -i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.c.hi32 = h, val.c.lo64 = l; return val; } -void i96_add(i96 *val, i96 x) { u64 lo64 = val->c.lo64 + x.c.lo64; val->c.hi32 += x.c.hi32 + (lo64 < val->c.lo64); val->c.lo64 = lo64; } -void OVERLOAD i96_sub(i96 *val, i96 x) { u64 lo64 = val->c.lo64 - x.c.lo64; val->c.hi32 -= x.c.hi32 + (lo64 > val->c.lo64); val->c.lo64 = lo64; } -void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } -void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->a.lo32 * x; val->a.lo32 = (u32)t; t = (u64)val->a.mid32 * x + (t >> 32); val->a.mid32 = (u32)t; val->a.hi32 = val->a.hi32 * x + (u32)(t >> 32); } -u32 i96_hi32(i96 val) { return val.c.hi32; } -u64 i96_lo64(i96 val) { return val.c.lo64; } -u64 i96_hi64(i96 val) { return ((u64) val.a.hi32 << 32) + val.a.mid32; } -u32 i96_lo32(i96 val) { return val.a.lo32; } -u32 i96_mid32(i96 val) { return val.a.mid32; } -#elif 0 -// An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions +// A primitive partial implementation of an i96 integer type +#if 0 +// An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. +// This version might be best on AMD and Intel if we can generate add-with-carry instructions. typedef struct { u32 lo32; u32 mid32; u32 hi32; } i96; -i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.hi32 = h, val.mid32 = m, val.lo32 = l; return val; } -i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.hi32 = hi32(h), val.mid32 = lo32(h), val.lo32 = l; return val; } -i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.hi32 = h, val.mid32 = hi32(l); val.lo32 = lo32(l); return val; } -void i96_add(i96 *val, i96 x) { val->lo32 += x.lo32; val->mid32 += x.mid32; val->hi32 += x.hi32 + (val->mid32 < x.mid32); u32 carry = (val->lo32 < x.lo32); val->mid32 += carry; val->hi32 += (val->mid32 < carry); } -void OVERLOAD i96_sub(i96 *val, i96 x) { i96 tmp = *val; val->lo32 -= x.lo32; val->mid32 -= x.mid32; val->hi32 -= x.hi32 + (val->mid32 > tmp.mid32); u32 carry = (val->lo32 > tmp.lo32); tmp = *val; val->mid32 -= carry; val->hi32 -= (val->mid32 > tmp.mid32); } -void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } -void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->lo32 * x; val->lo32 = (u32)t; t = (u64)val->mid32 * x + (t >> 32); val->mid32 = (u32)t; val->hi32 = val->hi32 * x + (u32)(t >> 32); } +i96 OVERLOAD make_i96(i128 v) { i96 val; val.hi32 = (u128)v >> 64, val.mid32 = (u64)v >> 32, val.lo32 = v; return val; } +i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.mid32 = v >> 32, val.lo32 = v; return val; } +i96 OVERLOAD make_i96(i32 v) { i96 val; val.hi32 = v >> 31, val.mid32 = v >> 31, val.lo32 = v; return val; } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.mid32 = lo >> 32, val.lo32 = lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi, val.mid32 = lo >> 32, val.lo32 = lo; return val; } u32 i96_hi32(i96 val) { return val.hi32; } u32 i96_mid32(i96 val) { return val.mid32; } u32 i96_lo32(i96 val) { return val.lo32; } u64 i96_lo64(i96 val) { return ((u64) val.mid32 << 32) | val.lo32; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | val.mid32; } -#elif 1 -// A u64 lo32, u32 hi32 implementation. This too would benefit from add.cc ASM instructions. +void i96_add(i96 *val, i96 x) { val->lo32 += x.lo32; val->mid32 += x.mid32; val->hi32 += x.hi32 + (val->mid32 < x.mid32); u32 carry = (val->lo32 < x.lo32); val->mid32 += carry; val->hi32 += (val->mid32 < carry); } +void i96_sub(i96 *val, i96 x) { i96 tmp = *val; val->lo32 -= x.lo32; val->mid32 -= x.mid32; val->hi32 -= x.hi32 + (val->mid32 > tmp.mid32); u32 carry = (val->lo32 > tmp.lo32); tmp = *val; val->mid32 -= carry; val->hi32 -= (val->mid32 > tmp.mid32); } +void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->lo32 * x; val->lo32 = (u32)t; t = (u64)val->mid32 * x + (t >> 32); val->mid32 = (u32)t; val->hi32 = val->hi32 * x + (u32)(t >> 32); } +#elif 0 +// A u64 lo32, u32 hi32 implementation. This too would benefit from add with carry instructions. +// On nVidia, the clang optimizer kept the hi32 value as 64-bits! typedef struct { u64 lo64; u32 hi32; } i96; -i96 OVERLOAD make_i96(u128 v) { i96 val; val.hi32 = v >> 64, val.lo64 = v; return val; } -i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.hi32 = h, val.lo64 = ((u64) m << 32) | l; return val; } -i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.hi32 = hi32(h), val.lo64 = ((u64) lo32(h) << 32) | l; return val; } -i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.hi32 = h, val.lo64 = l; return val; } +i96 OVERLOAD make_i96(i128 v) { i96 val; val.hi32 = (u128)v >> 64, val.lo64 = v; return val; } +i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.lo64 = v; return val; } +i96 OVERLOAD make_i96(i32 v) { return make_i96((i64)v); } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.lo64 = lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi, val.lo64 = lo; return val; } u32 i96_hi32(i96 val) { return val.hi32; } u32 i96_mid32(i96 val) { return hi32(val.lo64); } u32 i96_lo32(i96 val) { return val.lo64; } u64 i96_lo64(i96 val) { return val.lo64; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | i96_mid32(val); } void i96_add(i96 *val, i96 x) { val->lo64 += x.lo64; val->hi32 += x.hi32 + (val->lo64 < x.lo64); } -void OVERLOAD i96_sub(i96 *val, i96 x) { u64 tmp = val->lo64; val->lo64 -= x.lo64; val->hi32 -= x.hi32 + (val->lo64 > tmp); } -void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_sub(i96 *val, i96 x) { u64 tmp = val->lo64; val->lo64 -= x.lo64; val->hi32 -= x.hi32 + (val->lo64 > tmp); } void i96_mul(i96 *val, u32 x) { u64 t = i96_lo32(*val) * (u64)x; u32 lo32 = t; t = i96_mid32(*val) * (u64)x + (t >> 32); u32 mid32 = t; u32 hi32 = val->hi32 * x + (t >> 32); *val = make_i96(hi32, mid32, lo32); } #elif 1 -// A u128 implementation. This will use more registers. -typedef struct { u128 x; } i96; -i96 OVERLOAD make_i96(u128 v) { i96 val; val.x = v; return val; } -i96 OVERLOAD make_i96(u32 h, u32 m, u32 l) { i96 val; val.x = ((u128) h << 64) | ((u128) m << 32) | l; return val; } -i96 OVERLOAD make_i96(u64 h, u32 l) { i96 val; val.x = ((u128) h << 32) | l; return val; } -i96 OVERLOAD make_i96(u32 h, u64 l) { i96 val; val.x = ((u128) h << 64) | l; return val; } -u32 i96_hi32(i96 val) { return val.x >> 64; } -u32 i96_mid32(i96 val) { return val.x >> 32; } +// An i128 implementation. This might use more GPU registers. nVidia likes this version. +typedef struct { i128 x; } i96; +i96 OVERLOAD make_i96(i128 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(i64 v) { return make_i96((i128)v); } +i96 OVERLOAD make_i96(i32 v) { return make_i96((i128)v); } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.x = ((u128)hi << 64) + lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { return make_i96((i64)hi, lo); } +u32 i96_hi32(i96 val) { return (u128)val.x >> 64; } +u32 i96_mid32(i96 val) { return (u64)val.x >> 32; } u32 i96_lo32(i96 val) { return val.x; } u64 i96_lo64(i96 val) { return val.x; } -u64 i96_hi64(i96 val) { return val.x >> 32; } +u64 i96_hi64(i96 val) { return (u128)val.x >> 32; } void i96_add(i96 *val, i96 v) { val->x += v.x; } -void OVERLOAD i96_sub(i96 *val, i96 v) { val->x -= v.x; } -void OVERLOAD i96_sub(i96 *val, u64 x) { i96_sub(val, make_i96(0, x)); } +void i96_sub(i96 *val, i96 v) { val->x -= v.x; } void i96_mul(i96 *val, u32 x) { val->x *= x; } #endif From c07ed343e9de38c8f2d6f2187b72d8d16b7249e0 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 10 Oct 2025 20:48:33 +0000 Subject: [PATCH 051/115] Added quick=n to -tune parameters for config.txt options testing. n ranges from 1 (slowest,more accurate) to 10 (fastest,least accurate). Changed -tune FFT timings to use quicker/fewer iterations as exponent increases. --- src/Gpu.cpp | 32 ++++++++++++++------------ src/Gpu.h | 2 +- src/tune.cpp | 64 ++++++++++++++++++++++++++++------------------------ 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index d18dbaae..8e436554 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -148,10 +148,10 @@ Weights genWeights(FFTConfig fft, u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { for (u32 block = 0; block < nW; ++block) { for (u32 rep = 0; rep < 2; ++rep) { if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } - } - } - bits.push_back(b.to_ulong()); + } + } + } + bits.push_back(b.to_ulong()); } } assert(bits.size() == N / 32); @@ -1277,7 +1277,7 @@ void Gpu::doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u3 zAvg.update(z, roeSq.N); if (roeSq.max > 0.005) log("%sZ=%.0f (avg %.1f), ROEmax=%.3f, ROEavg=%.3f. %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), - z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); else log("%sZ=%.0f (avg %.1f) %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); @@ -1601,18 +1601,20 @@ tuple Gpu::measureROE(bool quick) { return {ok, res, roes.first, roes.second}; } -double Gpu::timePRP() { +double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest) to 10 (quickest, shortest) u32 blockSize{}, iters{}, warmup{}; - if (true) { - blockSize = 200; - iters = 1000; - warmup = 30; - } else { - blockSize = 1000; - iters = 10'000; - warmup = 100; - } + if (quick == 10) iters = 400, blockSize = 200; + else if (quick == 9) iters = 600, blockSize = 300; + else if (quick == 8) iters = 900, blockSize = 300; + else if (quick == 7) iters = 1200, blockSize = 400; + else if (quick == 6) iters = 1800, blockSize = 600; + else if (quick == 5) iters = 3000, blockSize = 1000; + else if (quick == 4) iters = 5000, blockSize = 1000; + else if (quick == 3) iters = 8000, blockSize = 1000; + else if (quick == 2) iters = 12000, blockSize = 1000; + else if (quick == 1) iters = 20000, blockSize = 1000; + warmup = 20; assert(iters % blockSize == 0); diff --git a/src/Gpu.h b/src/Gpu.h index 20a4f713..65708922 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -296,7 +296,7 @@ class Gpu { LLResult isPrimeLL(const Task& task); array isCERT(const Task& task); - double timePRP(); + double timePRP(int quick = 7); tuple measureROE(bool quick); tuple measureCarry(); diff --git a/src/tune.cpp b/src/tune.cpp index 30be06da..c7939d60 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -346,6 +346,7 @@ void Tune::tune() { bool tune_config = 1; bool time_FFTs = 0; bool time_NTTs = 0; + int quick = 7; // Run config from slowest (quick=1) to fastest (quick=10) u64 min_exponent = 75000000; u64 max_exponent = 350000000; if (!args->fftSpec.empty()) { min_exponent = 0; max_exponent = 1000000000000ull; } @@ -358,10 +359,13 @@ void Tune::tune() { if (s == "ntt") time_NTTs = 1; auto keyVal = split(s, '='); if (keyVal.size() == 2) { + if (keyVal.front() == "quick") quick = stod(keyVal.back()); if (keyVal.front() == "minexp") min_exponent = stoull(keyVal.back()); if (keyVal.front() == "maxexp") max_exponent = stoull(keyVal.back()); } } + if (quick < 1) quick = 1; + if (quick > 10) quick = 10; // Look for best settings of various options. Append best settings to config.txt. if (tune_config) { @@ -391,7 +395,7 @@ void Tune::tune() { } if (time_NTTs) { defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); - defaultShape = &defaultNTTShape; + defaultShape = &defaultNTTShape; } } // No user specifications. Time an FP64 FFT and a GF31*GF61 NTT to see if the GPU is more suited for FP64 work or NTT work. @@ -399,11 +403,11 @@ void Tune::tune() { log("Checking whether this GPU is better suited for double-precision FFTs or integer NTTs.\n"); defaultFFTShape = FFTShape(FFT64, 512, 16, 512); FFTConfig fft{defaultFFTShape, 101, CARRY_32}; - double fp64_time = Gpu::make(q, 141000001, shared, fft, {}, false)->timePRP(); + double fp64_time = Gpu::make(q, 141000001, shared, fft, {}, false)->timePRP(quick); log("Time for FP64 FFT %12s is %6.1f\n", fft.spec().c_str(), fp64_time); defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); FFTConfig ntt{defaultNTTShape, 202, CARRY_AUTO}; - double ntt_time = Gpu::make(q, 141000001, shared, ntt, {}, false)->timePRP(); + double ntt_time = Gpu::make(q, 141000001, shared, ntt, {}, false)->timePRP(quick); log("Time for M31*M61 NTT %12s is %6.1f\n", ntt.spec().c_str(), ntt_time); if (fp64_time < ntt_time) { defaultShape = &defaultFFTShape; @@ -447,7 +451,7 @@ void Tune::tune() { for (u32 in_sizex : {8, 16, 32}) { args->flags["IN_WG"] = to_string(in_wg); args->flags["IN_SIZEX"] = to_string(in_sizex); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using IN_WG=%u, IN_SIZEX=%u is %6.1f\n", fft.spec().c_str(), in_wg, in_sizex, cost); if (in_wg == current_in_wg && in_sizex == current_in_sizex) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } @@ -469,7 +473,7 @@ void Tune::tune() { for (u32 out_sizex : {8, 16, 32}) { args->flags["OUT_WG"] = to_string(out_wg); args->flags["OUT_SIZEX"] = to_string(out_sizex); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using OUT_WG=%u, OUT_SIZEX=%u is %6.1f\n", fft.spec().c_str(), out_wg, out_sizex, cost); if (out_wg == current_out_wg && out_sizex == current_out_sizex) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } @@ -492,7 +496,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 pad : {0, 64, 128, 256, 512}) { args->flags["PAD"] = to_string(pad); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); if (pad == current_pad) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } @@ -512,7 +516,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 nontemporal : {0, 1}) { args->flags["NONTEMPORAL"] = to_string(nontemporal); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); if (nontemporal == current_nontemporal) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } @@ -532,7 +536,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 fast_barrier : {0, 1}) { args->flags["FAST_BARRIER"] = to_string(fast_barrier); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using FAST_BARRIER=%u is %6.1f\n", fft.spec().c_str(), fast_barrier, cost); if (fast_barrier == current_fast_barrier) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_fast_barrier = fast_barrier; } @@ -552,7 +556,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tail_kernels : {0, 1, 2, 3}) { args->flags["TAIL_KERNELS"] = to_string(tail_kernels); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); if (tail_kernels == current_tail_kernels) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } @@ -575,7 +579,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tail_trigs : {0, 1, 2}) { args->flags["TAIL_TRIGS"] = to_string(tail_trigs); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_TRIGS=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } @@ -596,7 +600,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tail_trigs : {0, 1}) { args->flags["TAIL_TRIGS31"] = to_string(tail_trigs); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_TRIGS31=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } @@ -617,7 +621,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tail_trigs : {0, 1, 2}) { args->flags["TAIL_TRIGS32"] = to_string(tail_trigs); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_TRIGS32=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } @@ -638,7 +642,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tail_trigs : {0, 1}) { args->flags["TAIL_TRIGS61"] = to_string(tail_trigs); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_TRIGS61=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } @@ -658,7 +662,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TABMUL_CHAIN=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } @@ -679,7 +683,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { args->flags["TABMUL_CHAIN31"] = to_string(tabmul_chain); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TABMUL_CHAIN31=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } @@ -700,7 +704,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { args->flags["TABMUL_CHAIN32"] = to_string(tabmul_chain); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TABMUL_CHAIN32=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } @@ -721,7 +725,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { args->flags["TABMUL_CHAIN61"] = to_string(tabmul_chain); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TABMUL_CHAIN61=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } @@ -741,7 +745,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 unroll_w : {0, 1}) { args->flags["UNROLL_W"] = to_string(unroll_w); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using UNROLL_W=%u is %6.1f\n", fft.spec().c_str(), unroll_w, cost); if (unroll_w == current_unroll_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_w = unroll_w; } @@ -761,7 +765,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 unroll_h : {0, 1}) { args->flags["UNROLL_H"] = to_string(unroll_h); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using UNROLL_H=%u is %6.1f\n", fft.spec().c_str(), unroll_h, cost); if (unroll_h == current_unroll_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_h = unroll_h; } @@ -781,7 +785,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 zerohack_w : {0, 1}) { args->flags["ZEROHACK_W"] = to_string(zerohack_w); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using ZEROHACK_W=%u is %6.1f\n", fft.spec().c_str(), zerohack_w, cost); if (zerohack_w == current_zerohack_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_w = zerohack_w; } @@ -801,7 +805,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 zerohack_h : {0, 1}) { args->flags["ZEROHACK_H"] = to_string(zerohack_h); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using ZEROHACK_H=%u is %6.1f\n", fft.spec().c_str(), zerohack_h, cost); if (zerohack_h == current_zerohack_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_h = zerohack_h; } @@ -821,7 +825,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 middle_in_lds_transpose : {0, 1}) { args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); if (middle_in_lds_transpose == current_middle_in_lds_transpose) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } @@ -841,7 +845,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 middle_out_lds_transpose : {0, 1}) { args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); if (middle_out_lds_transpose == current_middle_out_lds_transpose) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } @@ -861,7 +865,7 @@ void Tune::tune() { double current_cost = -1.0; for (u32 biglit : {0, 1}) { args->flags["BIGLIT"] = to_string(biglit); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using BIGLIT=%u is %6.1f\n", fft.spec().c_str(), biglit, cost); if (biglit == current_biglit) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_biglit = biglit; } @@ -934,8 +938,10 @@ skip_1K_256 = 0; // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); - - // Loop through all possible variants +//GW: If user specified a quick != 7, adjust the formula below??? + quick = (exponent < 50000000) ? 6 : (exponent < 150000000) ? 7 : (exponent < 350000000) ? 8 : 10; + + // Loop through all possible variants for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { // Only FP64 code supports variants @@ -983,7 +989,7 @@ skip_1K_256 = 0; if (w == 0 && !AMDGPU) continue; if (w == 0 && test.width > 1024) continue; FFTConfig fft{test, variant_WMH (w, 0, 1), CARRY_32}; - cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(quick); log("Fast width search %6.1f %12s\n", cost, fft.spec().c_str()); if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_width = w; } } @@ -1005,7 +1011,7 @@ skip_1K_256 = 0; if (h == 0 && !AMDGPU) continue; if (h == 0 && test.height > 1024) continue; FFTConfig fft{test, variant_WMH (1, 0, h), CARRY_32}; - cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(quick); log("Fast height search %6.1f %12s\n", cost, fft.spec().c_str()); if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_height = h; } } @@ -1035,7 +1041,7 @@ skip_1K_256 = 0; // Skip middle = 1, CARRY_32 if maximum exponent would be the same as middle = 0, CARRY_32 if (variant_M(variant) > 0 && carry == CARRY_32 && fft.maxExp() <= FFTConfig{shape, variant - 10, CARRY_32}.maxExp()) continue; - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); bool isUseful = TuneEntry{cost, fft}.update(results); log("%c %6.1f %12s %9lu\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); } From ea270809b73dbe07bb83f6517005fbb6e15bc568 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 12 Oct 2025 00:38:34 +0000 Subject: [PATCH 052/115] Changed the way sleep time is computed when Queue is full. A shorter wait for NTTs and hybrid FFTs since they execute kernels to do a squaring. --- src/Gpu.cpp | 1 + src/Queue.cpp | 9 +++++---- src/Queue.h | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 8e436554..703a4ab6 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -725,6 +725,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& selftestTrig(); } + queue->setSquareKernels(1 + 3 * (fft.FFT_FP64 + fft.FFT_FP32 + fft.NTT_GF31 + fft.NTT_GF61)); queue->finish(); } diff --git a/src/Queue.cpp b/src/Queue.cpp index 829be7ae..8ef83adc 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -24,11 +24,12 @@ Queue::Queue(const Context& context, bool profile) : markerEvent{}, markerQueued(false), queueCount(0), - squareTime(50) + squareTime(50), + squareKernels(4) { // Formerly a constant (thus the CAPS). nVidia is 3% CPU load at 400 or 500, and 35% load at 800 on my Linux machine. // AMD is just over 2% load at 1600 and 3200 on the same Linux machine. Marginally better timings(?) at 3200. - MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings + MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings (if squareKernels = 4) } void Queue::writeTE(cl_mem buf, u64 size, const void* data, TimeInfo* tInfo) { @@ -102,8 +103,8 @@ void Queue::waitForMarkerEvent() { // By default, nVidia finish causes a CPU busy wait. Instead, sleep for a while. Since we know how many items are enqueued after the marker we can make an // educated guess of how long to sleep to keep CPU overhead low. while (getEventInfo(markerEvent) != CL_COMPLETE) { - // There are 4 kernels per squaring. Don't overestimate sleep time. Divide by 10 instead of 4. - std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / 10)); + // There are 4, 7, or 10 kernels per squaring. Don't overestimate sleep time. Divide by much more than the number of kernels. + std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / squareKernels / 2)); } markerQueued = false; } diff --git a/src/Queue.h b/src/Queue.h index 06efa159..28aea967 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -51,6 +51,7 @@ class Queue : public QueueHolder { void finish(); void setSquareTime(int); // Set the time to do one squaring (in microseconds) + void setSquareKernels(int n) { squareKernels = n; } private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. int MAX_QUEUE_COUNT; // Queue size before a marker will be enqueued. Typically, 100 to 1000 squarings. @@ -58,6 +59,7 @@ class Queue : public QueueHolder { bool markerQueued; // TRUE if a marker and event have been queued int queueCount; // Count of items added to the queue since last marker int squareTime; // Time to do one squaring (in microseconds) + int squareKernels; // Number of kernels in one squaring void queueMarkerEvent(); // Queue the marker event void waitForMarkerEvent(); // Wait for marker event to complete }; From 4cd8a387d34e7044838b20705972f804b709bcc0 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 13 Oct 2025 02:18:35 +0000 Subject: [PATCH 053/115] Deleted code for slower options. Faster FP32+M31+M61 carry propagation. --- src/cl/carryutil.cl | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index e3cda20a..3e3b3968 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -60,10 +60,9 @@ i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; ret //i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; if (bits == 1) return -(u & 1); return ulowFixedBits(u, bits - 1) - (u & (1 << bits)); } i32 OVERLOAD lowFixedBits(u32 u, const u32 bits) { return lowFixedBits((i32) u, bits); } #endif -// Return signed low bits where number of bits is known at compile time (number of bits can be 1 to 63) -//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u << (64 - bits)) >> (64 - bits)); } -//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u64) lowFixedBits((i32) ((u64) u >> 32), bits - 32) << 32) | (u32) u; } -i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } +// Return signed low bits where number of bits is known at compile time (number of bits can be 1 to 63). The two versions are the same speed on TitanV. +i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u << (64 - bits)) >> (64 - bits)); } +//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64) u, bits); } // Extract 32 bits from a 64-bit value @@ -371,7 +370,9 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, u64 n61 = get_Z61(u61); // The final result must be n61 mod M61. Use FP32 data to calculate this value. - uF2 = fma(uF2, F2_invWeight, - (float) n61); // This should be close to a multiple of M61 +// float n61f = (float)n61; // Convert n61 to float + float n61f = (float)((u32)(n61 >> 32)) * 4294967296.0f; // Conversion from u64 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, -n61f); // This should be close to a multiple of M61 float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by 2^61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); @@ -458,10 +459,10 @@ i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_ u64 n61 = get_Z61(u61); i128 n3161 = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 -//GW - computing (float)n3161 may be complicated and we don't need all that precision. . Would forming a smaller value (like just n61) to float be faster with the *M31 taken into account with the constant? - // The final result must be n3161 mod M31*M61. Use FP32 data to calculate this value. - uF2 = fma(uF2, F2_invWeight, - (float) n3161); // This should be close to a multiple of M31*M61 +// float n3161f = (float)n3161; // Convert n3161 to float + float n3161f = (float)((u32)(n61 >> 32)) * 9223372036854775808.0f; // Conversion from i128 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, -n3161f); // This should be close to a multiple of M31*M61 float uF2int = fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL); // Divide by M31*M61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); @@ -499,12 +500,6 @@ Word OVERLOAD carryStep(i128 x, i64 *outCarry, bool isBigWord) { Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); -//GWBUG - is this ever faster? Not on TitanV -//i128 x128 = ((i128) (i64) i96_hi64(x) << 32) | i96_lo32(x); -//i64 w = ((i64) x128 << (64 - nBits)) >> (64 - nBits); -//*outCarry = (i64) (x128 >> nBits) + (w < 0); -//return w; - // This code can be tricky because we must not shift i32 or u32 variables by 32. #if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? i64 xhi = i96_hi64(x); @@ -673,10 +668,6 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 32 // nBits is 32 or more -// i64 w = lowFixedBits(i96_lo64(x), bigwordBits); -// i64 xhi = ((i64) i96_hi64(x) >> (bigwordBits - 32)) + (w < 0); -// *outCarry = xhi << (bigwordBits - nBits); -// or this: u64 xlo = i96_lo64(x); u64 xlo_topbit = xlo & (1ULL << (bigwordBits - 1)); i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; From 9f20b388d2ebec1062dd32d4b1c318244aa58525 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 13 Oct 2025 05:01:45 +0000 Subject: [PATCH 054/115] More cleanup of less efficient carry propagation options. --- src/cl/carryutil.cl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 3e3b3968..b029a0e3 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -501,19 +501,15 @@ Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); // This code can be tricky because we must not shift i32 or u32 variables by 32. -#if EXP / NWORDS >= 33 //GWBUG Would the EXP / NWORDS == 32 code be just as fast? +#if EXP / NWORDS >= 33 i64 xhi = i96_hi64(x); i64 w = lowBits(xhi, nBits - 32); -// xhi -= w; //GWBUG - is (w < 0) version faster? -// *outCarry = xhi >> (nBits - 32); - *outCarry = (xhi >> (nBits - 32)) + (w < 0); + *outCarry = (xhi - w) >> (nBits - 32); return (w << 32) | i96_lo32(x); #elif EXP / NWORDS == 32 i64 xhi = i96_hi64(x); i64 w = lowBits(i96_lo64(x), nBits); -// xhi -= w >> 32; -// *outCarry = xhi >> (nBits - 32); //GWBUG - Is this ever faster than adding (w < 0)??? - *outCarry = (xhi >> (nBits - 32)) + (w < 0); + *outCarry = (xhi - (w >> 32)) >> (nBits - 32); return w; #elif EXP / NWORDS == 31 i64 w = lowBits(i96_lo64(x), nBits); From 28e683ab2302b19a5d526af784d85a4de39d4df1 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 15 Oct 2025 01:53:13 +0000 Subject: [PATCH 055/115] Added JSON text for type 4 hybrid FFT --- src/Task.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Task.cpp b/src/Task.cpp index 177128e3..3e277786 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -87,7 +87,7 @@ string ffttype(FFTConfig fft) { return fft.shape.fft_type == FFT64 ? "FP64" : fft.shape.fft_type == FFT3161 ? "M31+M61" : fft.shape.fft_type == FFT61 ? "M61" : fft.shape.fft_type == FFT3261 ? "FP32+M61" : fft.shape.fft_type == FFT31 ? "M31" : fft.shape.fft_type == FFT3231 ? "FP32+M31" : - fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : "unknown"; + fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : fft.shape.fft_type == FFT323164 ? "FP32+M31+M61" : "unknown"; } string json(const vector& v) { From da1375a67c1308371e545cde410ee297bcc3227d Mon Sep 17 00:00:00 2001 From: george Date: Wed, 15 Oct 2025 02:06:53 +0000 Subject: [PATCH 056/115] Fixed typo in last fix --- src/Task.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Task.cpp b/src/Task.cpp index 3e277786..540010fe 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -87,7 +87,7 @@ string ffttype(FFTConfig fft) { return fft.shape.fft_type == FFT64 ? "FP64" : fft.shape.fft_type == FFT3161 ? "M31+M61" : fft.shape.fft_type == FFT61 ? "M61" : fft.shape.fft_type == FFT3261 ? "FP32+M61" : fft.shape.fft_type == FFT31 ? "M31" : fft.shape.fft_type == FFT3231 ? "FP32+M31" : - fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : fft.shape.fft_type == FFT323164 ? "FP32+M31+M61" : "unknown"; + fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : fft.shape.fft_type == FFT323161 ? "FP32+M31+M61" : "unknown"; } string json(const vector& v) { From b744c6ea3bdf725f945caf6e4cb075bdae630c12 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 15 Oct 2025 21:15:31 +0000 Subject: [PATCH 057/115] Faster complex mul for GF61. Attempted inline PTX code with disappointing results. --- src/cl/base.cl | 15 +++++++++----- src/cl/math.cl | 53 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/cl/base.cl b/src/cl/base.cl index 7acd0fe2..781e832a 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -19,7 +19,8 @@ CARRY_LEN NW NH AMDGPU : if this is an AMD GPU -HAS_ASM : set if we believe __asm() can be used +HAS_ASM : set if we believe __asm() can be used for AMD GCN +HAS_PTX : set if we believe __asm() can be used for nVidia PTX -- Derived from above: BIG_HEIGHT == SMALL_HEIGHT * MIDDLE @@ -56,12 +57,16 @@ G_H "group height" == SMALL_HEIGHT / NH //__builtin_assume(condition) #endif // DEBUG -#if AMDGPU -// On AMDGPU the default is HAS_ASM -#if !NO_ASM +#if NO_ASM +#define HAS_ASM 0 +#define HAS_PTX 0 +#elif AMDGPU #define HAS_ASM 1 +#define HAS_PTX 0 +#else // Assume it is as nVidia GPU (can C code detect nVidia like it does for AMD?) +#define HAS_ASM 0 +#define HAS_PTX 1 #endif -#endif // AMDGPU // Default is not adding -2 to results for LL #if !defined(LL) diff --git a/src/cl/math.cl b/src/cl/math.cl index ac1749a8..309b1fe5 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -9,6 +9,30 @@ u32 lo32(u64 x) { return (u32) x; } u32 hi32(u64 x) { return (u32) (x >> 32); } +// Multiply and add primitives + +u128 mad32(u32 a, u32 b, u64 c) { +#if 0 && HAS_PTX // Same speed on TitanV, any gain may be too small to measure + u32 reslo, reshi; + __asm("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(reslo) : "r"(a), "r"(b), "r"((u32) c)); + __asm("madc.hi.u32 %0, %1, %2, %3;" : "=r"(reshi) : "r"(a), "r"(b), "r"((u32) (c >> 32))); + return ((u64)reshi << 32) | reslo; +#else + return (u64) a * (u64) b + c; +#endif +} + +u128 mad64(u64 a, u64 b, u128 c) { +#if 0 && HAS_PTX // Slower on TitanV, don't understand why + u64 reslo, reshi; + __asm("mad.lo.cc.u64 %0, %1, %2, %3;" : "=l"(reslo) : "l"(a), "l"(b), "l"((u64) c)); + __asm("madc.hi.u64 %0, %1, %2, %3;" : "=l"(reshi) : "l"(a), "l"(b), "l"((u64) (c >> 64))); + return ((u128)reshi << 64) | reslo; +#else + return (u128) a * (u128) b + c; +#endif +} + // A primitive partial implementation of an i96 integer type #if 0 // An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. @@ -499,7 +523,7 @@ GF31 OVERLOAD csq_subi(GF31 a, GF31 c) { } // Complex mul -#if 1 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. +#if 0 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. GF31 OVERLOAD cmul(GF31 a, GF31 b) { u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 u64 k2 = a.x * (u64) (b.y + neg(b.x)); @@ -511,10 +535,8 @@ GF31 OVERLOAD cmul(GF31 a, GF31 b) { #else GF31 OVERLOAD cmul(GF31 a, GF31 b) { u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 - u64 k2 = a.x * (u64) (b.y + neg(b.x)); - u64 k3 = neg(a.y) * (u64) (b.y + b.x); - u64 k1k3 = k1 + k3; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 - u64 k1k2 = k1 + k2; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u64 k1k2 = mad32(a.x, b.y + neg(b.x), k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u64 k1k3 = mad32(neg(a.y), b.y + b.x, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 return U2(modM31(k1k3), modM31(k1k2)); } #endif @@ -779,12 +801,33 @@ GF61 OVERLOAD csq(GF61 a) { return csqs(a, 2); } GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, a.x + neg(a.y, 2), 3, 4) + c.x), modM61(weakMul(a.x + a.x, a.y, 3, 2) + c.y)); } // Complex mul +#if 0 GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 3-epsilon extra bits in u64 Z61 k1 = weakMul(b.x, a.x + a.y, 2, 3); // max value is 3*M61+epsilon Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2), 2, 3); // max value is 3*M61+epsilon Z61 k3 = weakMul(a.y, b.y + b.x, 2, 3); // max value is 3*M61+epsilon return U2(modM61(k1 + neg(k3, 4)), modM61(k1 + k2)); } +#else +Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b_m61_count) { + u128 ab = mad64(a, b, c); // Max c value assumed to be M61^2+epsilon + u64 lo = ab, hi = ab >> 64; + u64 lo61 = lo & M61; // Max value is M61 + if ((a_m61_count - 1) * (b_m61_count - 1) + 1 <= 6) { + hi = (hi << 3) + (lo >> 61); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 1) * M61 + epsilon + return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61 + epsilon + } else { + u64 hi61 = ((hi << 3) + (lo >> 61)) & M61; // Max value is M61 + return lo61 + hi61 + (hi >> 58); // Max value is 2*M61 + epsilon + } +} +GF61 OVERLOAD cmul(GF61 a, GF61 b) { + u128 k1 = (u128) b.x * (u128) (a.x + a.y); // max value is M61^2+epsilon + Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 3); // max value is 4*M61+epsilon + Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 5*M61+epsilon + return U2(modM61(k1k3), modM61(k1k2)); +} +#endif // Square a root of unity complex number (the second version may be faster if the compiler optimizes the u128 squaring). //GF61 OVERLOAD csqTrig(GF61 a) { Z61 two_ay = a.y + a.y; return U2(modM61(1 + weakMul(two_ay, neg(a.y, 2))), mul(a.x, two_ay)); } From bd5e35dd60631bb28d0d8ff5b62a5a99a8f20d8a Mon Sep 17 00:00:00 2001 From: george Date: Wed, 15 Oct 2025 22:07:07 +0000 Subject: [PATCH 058/115] Changed the min BPW for FP32-only FFTs from 3.0 to 1.0. FP32-only FFTs are still real flaky. --- src/FFTConfig.h | 4 ++-- src/Gpu.cpp | 6 ++---- src/cl/carryutil.cl | 9 +++++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/FFTConfig.h b/src/FFTConfig.h index ece9fcec..c873eb8c 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -25,8 +25,6 @@ enum FFT_TYPES {FFT64=0, FFT3161=1, FFT3261=2, FFT61=3, FFT323161=4, FFT3231=50, class FFTShape { public: - static constexpr const float MIN_BPW = 3; - static std::vector allShapes(u32 from=0, u32 to = -1); static tuple getChainLengths(u32 fftSize, u32 exponent, u32 middle); @@ -48,6 +46,7 @@ class FFTShape { u32 nW() const { return (width == 1024 || width == 256 /*|| width == 4096*/) ? 4 : 8; } u32 nH() const { return (height == 1024 || height == 256 /*|| height == 4096*/) ? 4 : 8; } + float minBpw() const { return fft_type != FFT32 ? 3.0f : 1.0f; } float maxBpw() const { return *max_element(bpw.begin(), bpw.end()); } std::string spec() const { return (fft_type ? to_string(fft_type) + ':' : "") + numberK(width) + ':' + numberK(middle) + ':' + numberK(height); } @@ -97,5 +96,6 @@ struct FFTConfig { u64 size() const { return shape.size(); } u64 maxExp() const { return maxBpw() * shape.size(); } + float minBpw() const { return shape.minBpw(); } float maxBpw() const; }; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 703a4ab6..00a2d7e2 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -646,12 +646,10 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& } } -#if !FFT_FP32 - if (bitsPerWord < FFTShape::MIN_BPW) { - log("FFT size too large for exponent (%.2f bits/word < %.2f bits/word).\n", bitsPerWord, FFTShape::MIN_BPW); + if (bitsPerWord < fft.minBpw()) { + log("FFT size too large for exponent (%.2f bits/word < %.2f bits/word).\n", bitsPerWord, fft.minBpw()); throw "FFT size too large"; } -#endif useLongCarry = useLongCarry || (bitsPerWord < 10.0); diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index b029a0e3..9e30c41c 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -10,8 +10,12 @@ typedef i32 CFcarry; #endif // The carry for the non-fused CarryA, CarryB, CarryM kernels. -// Simply use large carry always as the split kernels are slow anyway (and seldomly used normally). +// Simply use largest possible carry always as the split kernels are slow anyway (and seldomly used normally). +#if COMBO_FFT || !(FFT_FP32 || NTT_GF31) typedef i64 CarryABM; +#else +typedef i32 CarryABM; +#endif /********************************/ /* Helper routines */ @@ -753,7 +757,8 @@ Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { #undef iCARRY #endif +#if COMBO_FFT || !(FFT_FP32 || NTT_GF31) #define iCARRY i64 #include "carryinc.cl" #undef iCARRY - +#endif From 30cfce8939b08756fea39cad44691d5cfeaa0f2a Mon Sep 17 00:00:00 2001 From: george Date: Wed, 15 Oct 2025 22:25:20 +0000 Subject: [PATCH 059/115] Fixed crash during setup of middle=1 when width and height are 256. Middle=1 FFTs still don't work. --- src/FFTConfig.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 09b07562..8f9b7168 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -155,6 +155,7 @@ FFTShape::FFTShape(enum FFT_TYPES t, u32 w, u32 m, u32 h) : while (w >= 4*h) { w /= 2; h *= 2; } while (w < h || w < 256 || w == 2048) { w *= 2; h /= 2; } while (h < 256) { h *= 2; m /= 2; } + if (m == 1) m = 2; bpw = FFTShape{w, m, h}.bpw; for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) bpw[j] -= 0.05; // Assume this fft spec is worse than measured fft specs if (this->isFavoredShape()) { // Don't output this warning message for non-favored shapes (we expect the BPW info to be missing) From ce3c7a2e0ab392cbcb3686720c1dc5903d56cb2b Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 03:09:58 +0000 Subject: [PATCH 060/115] Added assert at Mihai suggestion --- src/Gpu.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 00a2d7e2..33b844ac 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1047,6 +1047,7 @@ void Gpu::writeWords(Buffer& buf, vector &words) { else { vector GPUdata; GPUdata.resize(words.size() / 2); + assert((words.size() & 1) == 0); for (u32 i = 0; i < words.size(); i += 2) { GPUdata[i/2] = ((i64) words[i+1] << 32) | (u32) words[i]; } From cc809c15f243ee21ca73a07a3bd206189587ba95 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 03:22:16 +0000 Subject: [PATCH 061/115] Prettier parts code at Mihai's suggestion --- src/FFTConfig.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 8f9b7168..fdead171 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -54,8 +54,7 @@ vector FFTShape::multiSpec(const string& iniSpec) { auto parts = split(spec, ':'); if (parseInt(parts[0]) < 60) { // Look for a prefix specifying the FFT type fft_type = (enum FFT_TYPES) parseInt(parts[0]); - for (u32 i = 1; i < parts.size(); ++i) parts[i-1] = parts[i]; - parts.resize(parts.size() - 1); + parts = vector(next(parts.begin()), parts.end()); } assert(parts.size() <= 3); if (parts.size() == 3) { From a93af50c97f1af3b38b64dd032baed1d8a1597a1 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 03:29:17 +0000 Subject: [PATCH 062/115] Deleted test code using signed GF61 intermediates --- src/cl/fft4.cl | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/src/cl/fft4.cl b/src/cl/fft4.cl index 29eaa786..8b422ca9 100644 --- a/src/cl/fft4.cl +++ b/src/cl/fft4.cl @@ -206,7 +206,6 @@ void OVERLOAD fft4by(GF61 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + step * k) % M] -#if !TEST_SHL Z61 x0 = addq(A(0).x, A(2).x); // Max value is 2*M61+epsilon Z61 x2 = subq(A(0).x, A(2).x, 2); // Max value is 3*M61+epsilon Z61 y0 = addq(A(0).y, A(2).y); @@ -234,42 +233,6 @@ void OVERLOAD fft4by(GF61 *u, u32 base, u32 step, u32 M) { A(2) = U2(a1, b1); A(3) = U2(a3, b3); -#else // Test case to see if signed M61 mod would be faster (if so, look into creating X2q options in math.cl's GF61 to support signed intermediates) - - i64 x0 = A(0).x + A(2).x; - i64 x2 = A(0).x - A(2).x; - i64 y0 = A(0).y + A(2).y; - i64 y2 = A(0).y - A(2).y; - - i64 x1 = A(1).x + A(3).x; - i64 y3 = A(1).x - A(3).x; - i64 y1 = A(1).y + A(3).y; - i64 x3 = A(3).y - A(1).y; - - i64 a0 = x0 - x1; - i64 a1 = x0 - x1; - - i64 b0 = y0 + y1; - i64 b1 = y0 - y1; - - i64 a2 = x2 + x3; - i64 a3 = x2 - x3; - - i64 b2 = y2 + y3; - i64 b3 = y2 - y3; - -#define cvt(a) (Z61) ((a & M61) + (a >> MBITS)) - - A(0) = U2(cvt(a0), cvt(b0)); - A(1) = U2(cvt(a2), cvt(b2)); - A(2) = U2(cvt(a1), cvt(b1)); - A(3) = U2(cvt(a3), cvt(b3)); - -#undef cvt - -#endif - - #undef A } From 2383424b792f6ee44cf2ec2c7f2be184f76c4df1 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 04:39:02 +0000 Subject: [PATCH 063/115] Fixed bug in ROE calculations for M31+M61 NTT --- src/cl/carry.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 26bc6e91..93643311 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -558,7 +558,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u carryOut[G_W * g + me] = carry; #if ROE - float fltRoundMax = (float) roundMax / (float) 0x0FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float. + float fltRoundMax = (float) roundMax / (float) 0x1FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float. updateStats(bufROE, posROE, fltRoundMax); #elif (STATS & (1 << (2 + MUL3))) updateStats(bufROE, posROE, carryMax); From 09e0c23cdd143759f0b37f6c22be3841358c6c6e Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 19:12:11 +0000 Subject: [PATCH 064/115] Pass FFT_TYPE to OpenCL code -- makes carryfused, fftp, carryutil, etc. more readable --- src/FFTConfig.cpp | 16 ++++++++-------- src/Gpu.cpp | 3 ++- src/cl/base.cl | 19 ++++++++++++------- src/cl/carry.cl | 18 +++++++++--------- src/cl/carryfused.cl | 18 +++++++++--------- src/cl/carryinc.cl | 18 +++++++++--------- src/cl/carryutil.cl | 24 ++++++++++++------------ src/cl/fftp.cl | 18 +++++++++--------- 8 files changed, 70 insertions(+), 64 deletions(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index fdead171..f60d5d30 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -235,15 +235,15 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : assert(variant_M(variant) < N_VARIANT_M); assert(variant_H(variant) < N_VARIANT_H); - if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; - else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; - else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 8; - else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 4; + if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 4; else if (shape.fft_type == FFT323161) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; - else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; - else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 8; - else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; - else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 8; + else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; else throw "FFT type"; } diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 33b844ac..deab198c 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -331,7 +331,8 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< defines += toDefine("TAILTGF61", root1GF61(fft.shape.height * 2, 1)); } - // Enable/disable code for each possible FP and NTT + // Send the FFT/NTT type and booleans that enable/disable code for each possible FP and NTT + defines += toDefine("FFT_TYPE", (int) fft.shape.fft_type); defines += toDefine("FFT_FP64", (int) fft.FFT_FP64); defines += toDefine("FFT_FP32", (int) fft.FFT_FP32); defines += toDefine("NTT_GF31", (int) fft.NTT_GF31); diff --git a/src/cl/base.cl b/src/cl/base.cl index 781e832a..2611d281 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -186,13 +186,18 @@ typedef ulong2 GF61; // A complex value using two Z61s. For a GF(M61^2) //typedef ulong NCW; // A value calculated mod 2^64 - 2^32 + 1. //typedef ulong2 NCW2; // A complex value using NCWs. For a Nick Craig-Wood's insipred NTT using prime 2^64 - 2^32 + 1. -// Typedefs for "combo" FFT/NTTs (multiple NTT primes or hybrid FFT/NTT). -#define COMBO_FFT (FFT_FP64 + FFT_FP32 + NTT_GF31 + NTT_GF61 > 1) -// Sanity check for supported FFT/NTT -#if (FFT_FP64 & NTT_GF31 & !FFT_FP32 & !NTT_GF61) | (NTT_GF31 & NTT_GF61 & !FFT_FP64 & !FFT_FP32) | (FFT_FP32 & NTT_GF61 & !FFT_FP64 & !NTT_GF31) | (FFT_FP32 & NTT_GF31 & NTT_GF61 & !FFT_FP64) -#elif !COMBO_FFT | (FFT_FP32 & NTT_GF31 & !FFT_FP64 & !NTT_GF61) -#else -error - unsupported FFT/NTT combination +// Defines for the various supported FFTs/NTTs. These match the enumeration in FFTConfig.h. Sanity check for supported FFT/NTT. +#define FFT64 0 +#define FFT3161 1 +#define FFT3261 2 +#define FFT61 3 +#define FFT323161 4 +#define FFT3231 50 +#define FFT6431 51 +#define FFT31 52 +#define FFT32 53 +#if FFT_TYPE < 0 || (FFT_TYPE > 4 && FFT_TYPE < 50) || FFT_TYPE > 53 +#error - unsupported FFT/NTT #endif // Word and Word2 define the data type for FFT integers passed between the CPU and GPU. #if WordSize == 8 diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 93643311..c1ebbf3a 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -3,7 +3,7 @@ #include "carryutil.cl" #include "weight.cl" -#if FFT_FP64 & !COMBO_FFT +#if FFT_TYPE == FFT64 // Carry propagation with optional MUL-3, over CARRY_LEN words. // Input arrives with real and imaginary values swapped and weighted. @@ -49,7 +49,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ -#elif FFT_FP32 & !COMBO_FFT +#elif FFT_TYPE == FFT32 // Carry propagation with optional MUL-3, over CARRY_LEN words. // Input arrives with real and imaginary values swapped and weighted. @@ -94,7 +94,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(F2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#elif NTT_GF31 & !COMBO_FFT +#elif FFT_TYPE == FFT31 KERNEL(G_W) carry(P(Word2) out, CP(GF31) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { u32 g = get_group_id(0); @@ -168,7 +168,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(GF31) in, u32 posROE, P(CarryABM) carryOut, P /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF61 & !COMBO_FFT +#elif FFT_TYPE == FFT61 KERNEL(G_W) carry(P(Word2) out, CP(GF61) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { u32 g = get_group_id(0); @@ -242,7 +242,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(GF61) in, u32 posROE, P(CarryABM) carryOut, P /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP64 & NTT_GF31 +#elif FFT_TYPE == FFT6431 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTab THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); @@ -320,7 +320,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 +#elif FFT_TYPE == FFT3231 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); @@ -399,7 +399,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3261 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); @@ -478,7 +478,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3161 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { u32 g = get_group_id(0); @@ -570,7 +570,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u /* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ /******************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT323161 KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index fe788df2..3df1e961 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -16,7 +16,7 @@ void spin() { #endif } -#if FFT_FP64 & !COMBO_FFT +#if FFT_TYPE == FFT64 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -231,7 +231,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ -#elif FFT_FP32 & !COMBO_FFT +#elif FFT_TYPE == FFT32 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -421,7 +421,7 @@ KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#elif NTT_GF31 & !COMBO_FFT +#elif FFT_TYPE == FFT31 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -642,7 +642,7 @@ KERNEL(G_W) carryFused(P(GF31) out, CP(GF31) in, u32 posROE, P(i64) carryShuttle /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF61 & !COMBO_FFT +#elif FFT_TYPE == FFT61 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -863,7 +863,7 @@ KERNEL(G_W) carryFused(P(GF61) out, CP(GF61) in, u32 posROE, P(i64) carryShuttle /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP64 & NTT_GF31 +#elif FFT_TYPE == FFT6431 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1106,7 +1106,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 +#elif FFT_TYPE == FFT3231 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1352,7 +1352,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3261 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1598,7 +1598,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3161 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) @@ -1857,7 +1857,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( /* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ /******************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT323161 // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index a8d36307..00bd1976 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -13,7 +13,7 @@ Word2 OVERLOAD carryFinal(Word2 u, iCARRY inCarry, bool b1) { /* Original FP64 version to start the carry propagation process for a pair of FFT values */ /*******************************************************************************************/ -#if FFT_FP64 & !COMBO_FFT +#if FFT_TYPE == FFT64 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -43,7 +43,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, bool b1 /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ -#elif FFT_FP32 & !COMBO_FFT +#elif FFT_TYPE == FFT32 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -73,7 +73,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 u, F2 invWeight, iCARRY inCarry, bool /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#elif NTT_GF31 & !COMBO_FFT +#elif FFT_TYPE == FFT31 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -103,7 +103,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF61 & !COMBO_FFT +#elif FFT_TYPE == FFT61 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -133,7 +133,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF61 u, u32 invWeight1, u32 invWeight2, /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP64 & NTT_GF31 +#elif FFT_TYPE == FFT6431 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -165,7 +165,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeigh /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 +#elif FFT_TYPE == FFT3231 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -197,7 +197,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWei /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3261 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -229,7 +229,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWei /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3161 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. @@ -260,7 +260,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, /* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ /******************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT323161 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 9e30c41c..6ae6c9f6 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -11,7 +11,7 @@ typedef i32 CFcarry; // The carry for the non-fused CarryA, CarryB, CarryM kernels. // Simply use largest possible carry always as the split kernels are slow anyway (and seldomly used normally). -#if COMBO_FFT || !(FFT_FP32 || NTT_GF31) +#if FFT_TYPE != FFT32 && FFT_TYPE != FFT31 typedef i64 CarryABM; #else typedef i32 CarryABM; @@ -151,7 +151,7 @@ void ROUNDOFF_CHECK(double x) { /* From the FFT data, construct a value to normalize and carry propagate */ /***************************************************************************/ -#if FFT_FP64 & !COMBO_FFT +#if FFT_TYPE == FFT64 #define SLOPPY_MAXBPW 173 // Based on 142.4M expo in 7.5M FFT = 18.36 BPW @@ -194,7 +194,7 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ -#elif FFT_FP32 & !COMBO_FFT +#elif FFT_TYPE == FFT32 #define SLOPPY_MAXBPW 0 // F32 FFTs are not practical @@ -235,7 +235,7 @@ i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_r /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#elif NTT_GF31 & !COMBO_FFT +#elif FFT_TYPE == FFT31 #define SLOPPY_MAXBPW 73 // Based on 140M expo in 16M FFT = 8.34 BPW @@ -263,7 +263,7 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF61 & !COMBO_FFT +#elif FFT_TYPE == FFT61 #define SLOPPY_MAXBPW 225 // Based on 198M expo in 8M FFT = 23.6 BPW @@ -291,7 +291,7 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP64 & NTT_GF31 +#elif FFT_TYPE == FFT6431 #define SLOPPY_MAXBPW 327 // Based on 142M expo in 4M FFT = 33.86 BPW @@ -329,7 +329,7 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 +#elif FFT_TYPE == FFT3231 #define SLOPPY_MAXBPW 154 // Based on 138M expo in 8M FFT = 16.45 BPW @@ -362,7 +362,7 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3261 #define SLOPPY_MAXBPW 309 // Based on 134M expo in 4M FFT = 31.95 BPW @@ -402,7 +402,7 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3161 #define SLOPPY_MAXBPW 383 // Based on 165M expo in 4M FFT = 39.34 BPW @@ -445,7 +445,7 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 /* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ /******************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT323161 #define SLOPPY_MAXBPW 461 // Based on 198M expo in 4M FFT = 47.20 BPW @@ -749,7 +749,7 @@ Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { /* Do this last, it depends on weightAndCarryOne defined above */ /**************************************************************************/ -/* Support both 32-bit and 64-bit carries */ // GWBUG - not all NTTs need to support both carries +/* Support both 32-bit and 64-bit carries */ #if WordSize <= 4 #define iCARRY i32 @@ -757,7 +757,7 @@ Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { #undef iCARRY #endif -#if COMBO_FFT || !(FFT_FP32 || NTT_GF31) +#if FFT_TYPE != FFT32 && FFT_TYPE != FFT31 #define iCARRY i64 #include "carryinc.cl" #undef iCARRY diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index 6a0f65fe..d430f984 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -6,7 +6,7 @@ #include "fftwidth.cl" #include "middle.cl" -#if FFT_FP64 & !COMBO_FFT +#if FFT_TYPE == FFT64 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { @@ -37,7 +37,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) /* Similar to above, but for an FFT based on FP32 */ /**************************************************************************/ -#elif FFT_FP32 & !COMBO_FFT +#elif FFT_TYPE == FFT32 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(F2) out, CP(Word2) in, TrigFP32 smallTrig, BigTabFP32 THREAD_WEIGHTS) { @@ -68,7 +68,7 @@ KERNEL(G_W) fftP(P(F2) out, CP(Word2) in, TrigFP32 smallTrig, BigTabFP32 THREAD_ /* Similar to above, but for an NTT based on GF(M31^2) */ /**************************************************************************/ -#elif NTT_GF31 & !COMBO_FFT +#elif FFT_TYPE == FFT31 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(GF31) out, CP(Word2) in, TrigGF31 smallTrig) { @@ -123,7 +123,7 @@ KERNEL(G_W) fftP(P(GF31) out, CP(Word2) in, TrigGF31 smallTrig) { /* Similar to above, but for an NTT based on GF(M61^2) */ /**************************************************************************/ -#elif NTT_GF61 & !COMBO_FFT +#elif FFT_TYPE == FFT61 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(GF61) out, CP(Word2) in, TrigGF61 smallTrig) { @@ -180,7 +180,7 @@ KERNEL(G_W) fftP(P(GF61) out, CP(Word2) in, TrigGF61 smallTrig) { /* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP64 & NTT_GF31 +#elif FFT_TYPE == FFT6431 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { @@ -248,7 +248,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) /* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ /**************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & !NTT_GF61 +#elif FFT_TYPE == FFT3231 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { @@ -318,7 +318,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG /* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ /**************************************************************************/ -#elif FFT_FP32 & !NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3261 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { @@ -388,7 +388,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG /* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ /**************************************************************************/ -#elif !FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT3161 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { @@ -471,7 +471,7 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { /* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ /******************************************************************************/ -#elif FFT_FP32 & NTT_GF31 & NTT_GF61 +#elif FFT_TYPE == FFT323161 // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { From 0d1a6240bdf8af3cda0fd17949a352f6c65236da Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 19:34:08 +0000 Subject: [PATCH 065/115] Compute FRAC_BITS_BIGSTEP in openCL code (like GF31 and GF61 NTTs do) rather than in C code. --- src/Gpu.cpp | 2 -- src/cl/carryfused.cl | 16 +++++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index deab198c..b1d6390d 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -405,8 +405,6 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< bpw--; // bpw must not be an exact value -- it must be less than exact value to get last biglit value right defines += toDefine("FRAC_BPW_HI", (u32) (bpw >> 32)); defines += toDefine("FRAC_BPW_LO", (u32) bpw); - u32 bigstep = (bpw * (N / fft.shape.nW())) >> 32; - defines += toDefine("FRAC_BITS_BIGSTEP", bigstep); return defines; } diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index 3df1e961..22d55ad8 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -91,20 +91,21 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. u32 word_index = (me * H + line) * 2; u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); + const u32 frac_bits_bigstep = ((G_W * H * 2) * FRAC_BPW_HI + (u32)(((u64)(G_W * H * 2) * FRAC_BPW_LO) >> 32)); #endif // Apply the inverse weights and carry propagate pairs to generate the output carries T invBase = optionalDouble(weights.x); - + for (u32 i = 0; i < NW; ++i) { T invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); T invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); // Generate big-word/little-word flags #if BIGLIT - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; - bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; // Same as frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; + bool biglit1 = frac_bits + i * frac_bits_bigstep >= -FRAC_BPW_HI; // Same as frac_bits + i * frac_bits_bigstep + FRAC_BPW_HI <= FRAC_BPW_HI; #else bool biglit0 = test(b, 2 * i); bool biglit1 = test(b, 2 * i + 1); @@ -210,7 +211,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Apply each 32 or 64 bit carry to the 2 words for (i32 i = 0; i < NW; ++i) { #if BIGLIT - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; #else bool biglit0 = test(b, 2 * i); #endif @@ -291,6 +292,7 @@ KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P( // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. u32 word_index = (me * H + line) * 2; u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); + const u32 frac_bits_bigstep = ((G_W * H * 2) * FRAC_BPW_HI + (u32)(((u64)(G_W * H * 2) * FRAC_BPW_LO) >> 32)); // Apply the inverse weights and carry propagate pairs to generate the output carries @@ -301,8 +303,8 @@ KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P( F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); // Generate big-word/little-word flags - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; - bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; // Same as frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; + bool biglit1 = frac_bits + i * frac_bits_bigstep >= -FRAC_BPW_HI; // Same as frac_bits + i * frac_bits_bigstep + FRAC_BPW_HI <= FRAC_BPW_HI; // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will @@ -403,7 +405,7 @@ KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P( // Apply each 32 or 64 bit carry to the 2 words for (i32 i = 0; i < NW; ++i) { - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; wu[i] = carryFinal(wu[i], carry[i], biglit0); u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); } From 6204f3d196d37b75de50107aaea74fae02038743 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 16 Oct 2025 23:45:38 +0000 Subject: [PATCH 066/115] Faster startup when beginning a new PRP test (long overdue for developers) --- src/Gpu.cpp | 40 +++++++++++++++++++++------------------- src/Gpu.h | 2 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index b1d6390d..b1ebd23c 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -962,32 +962,34 @@ void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { mul(ioA, buf1, buf2, buf3, mul3); }; -void Gpu::writeState(const vector& check, u32 blockSize) { +void Gpu::writeState(u32 k, const vector& check, u32 blockSize) { assert(blockSize > 0); writeIn(bufCheck, check); bufData << bufCheck; bufAux << bufCheck; - u32 n; - for (n = 1; blockSize % (2 * n) == 0; n *= 2) { - squareLoop(bufData, 0, n); - modMul(bufData, bufAux); - bufAux << bufData; - } + if (k) { // Only verify bufData that was read in from a save file + u32 n; + for (n = 1; blockSize % (2 * n) == 0; n *= 2) { + squareLoop(bufData, 0, n); + modMul(bufData, bufAux); + bufAux << bufData; + } - assert((n & (n - 1)) == 0); - assert(blockSize % n == 0); + assert((n & (n - 1)) == 0); + assert(blockSize % n == 0); - blockSize /= n; - assert(blockSize >= 2); + blockSize /= n; + assert(blockSize >= 2); + + for (u32 i = 0; i < blockSize - 2; ++i) { + squareLoop(bufData, 0, n); + modMul(bufData, bufAux); + } - for (u32 i = 0; i < blockSize - 2; ++i) { squareLoop(bufData, 0, n); - modMul(bufData, bufAux); } - - squareLoop(bufData, 0, n); modMul(bufData, bufAux, true); } @@ -1439,7 +1441,7 @@ PRPState Gpu::loadPRP(Saver& saver) { } PRPState state = saver.load(); - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); u64 res = dataResidue(); if (res == state.res64) { @@ -1482,7 +1484,7 @@ tuple Gpu::measureCarry() { u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); { u64 res = dataResidue(); if (res != state.res64) { @@ -1550,7 +1552,7 @@ tuple Gpu::measureROE(bool quick) { u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); { u64 res = dataResidue(); if (res != state.res64) { @@ -1619,7 +1621,7 @@ double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); assert(dataResidue() == state.res64); modMul(bufCheck, bufData); diff --git a/src/Gpu.h b/src/Gpu.h index 65708922..230e8059 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -259,7 +259,7 @@ class Gpu { void bottomHalf(Buffer& out, Buffer& inTmp); - void writeState(const vector& check, u32 blockSize); + void writeState(u32 k, const vector& check, u32 blockSize); // does either carrryFused() or the expanded version depending on useLongCarry void doCarry(Buffer& out, Buffer& in); From 5fa3a83187667cf1c86a7ef153f837d183db97a1 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 17 Oct 2025 03:37:16 +0000 Subject: [PATCH 067/115] Saved one negation in hybrid carry propagation --- src/cl/carryutil.cl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 6ae6c9f6..4346cd16 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -374,10 +374,9 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, u64 n61 = get_Z61(u61); // The final result must be n61 mod M61. Use FP32 data to calculate this value. -// float n61f = (float)n61; // Convert n61 to float - float n61f = (float)((u32)(n61 >> 32)) * 4294967296.0f; // Conversion from u64 to float might be slow, this might be faster - uF2 = fma(uF2, F2_invWeight, -n61f); // This should be close to a multiple of M61 - float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by 2^61 and round to int + float n61f = (float)((u32)(n61 >> 32)) * -4294967296.0f; // Conversion from u64 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, n61f); // This should be close to a multiple of M61 + float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by M61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); // Optionally calculate roundoff error @@ -464,9 +463,8 @@ i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_ i128 n3161 = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 // The final result must be n3161 mod M31*M61. Use FP32 data to calculate this value. -// float n3161f = (float)n3161; // Convert n3161 to float - float n3161f = (float)((u32)(n61 >> 32)) * 9223372036854775808.0f; // Conversion from i128 to float might be slow, this might be faster - uF2 = fma(uF2, F2_invWeight, -n3161f); // This should be close to a multiple of M31*M61 + float n3161f = (float)((u32)(n61 >> 32)) * -9223372036854775808.0f; // Conversion from i128 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, n3161f); // This should be close to a multiple of M31*M61 float uF2int = fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL); // Divide by M31*M61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); From 25522f7fe991da44f427956c7bb0edd192c98e05 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 17 Oct 2025 17:23:23 +0000 Subject: [PATCH 068/115] Filled out the max BPW table. Allow Z=6 for NTTs without any warning. --- src/Gpu.cpp | 2 +- src/fftbpw.h | 232 +++++++++++++++++++++++++-------------------------- 2 files changed, 116 insertions(+), 118 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index b1ebd23c..3ca77b04 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1283,7 +1283,7 @@ void Gpu::doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u3 log("%sZ=%.0f (avg %.1f) %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); - if (roeSq.N > 2 && z < 20) { + if (roeSq.N > 2 && (z < 6 || (fft.shape.fft_type == FFT64 && z < 20))) { log("Danger ROE! Z=%.1f is too small, increase precision or FFT size!\n", z); } diff --git a/src/fftbpw.h b/src/fftbpw.h index 86392e3d..c62db530 100644 --- a/src/fftbpw.h +++ b/src/fftbpw.h @@ -1,3 +1,4 @@ +// FFT64 - Computed by targeting Z=28 { "256:2:256", {19.204, 19.547, 19.636, 19.204, 19.547, 19.636}}, { "256:3:256", {19.106, 19.386, 19.369, 19.106, 19.399, 19.361}}, { "256:4:256", {18.928, 19.236, 19.322, 18.954, 19.272, 19.367}}, @@ -91,120 +92,117 @@ { "4K:14:1K", {16.942, 17.055, 17.160, 17.027, 17.127, 17.306}}, { "4K:15:1K", {17.021, 17.007, 17.137, 17.104, 17.087, 17.282}}, { "4K:16:1K", {16.744, 16.887, 16.966, 16.921, 17.048, 17.208}}, -// FFT3161 -{ "1:256:2:256", {39.74, 39.74, 39.74, 39.74, 39.74, 39.74}}, -{ "1:256:4:256", {39.64, 39.64, 39.64, 39.64, 39.64, 39.64}}, -{ "1:256:8:256", {39.54, 39.54, 39.54, 39.54, 39.54, 39.54}}, -{ "1:512:4:256", {39.54, 39.54, 39.54, 39.54, 39.54, 39.54}}, -{"1:256:16:256", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, -{ "1:512:8:256", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, -{ "1:512:4:512", {39.44, 39.44, 39.44, 39.44, 39.44, 39.44}}, -{ "1:1K:8:256", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, -{"1:512:16:256", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, -{ "1:512:8:512", {39.34, 39.34, 39.34, 39.34, 39.34, 39.34}}, -{ "1:1K:16:256", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, -{ "1:1K:8:512", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, -{"1:512:16:512", {39.24, 39.24, 39.24, 39.24, 39.24, 39.24}}, -{ "1:1K:16:512", {39.14, 39.14, 39.14, 39.14, 39.14, 39.14}}, -{ "1:1K:8:1K", {39.14, 39.14, 39.14, 39.14, 39.14, 39.14}}, -{ "1:1K:16:1K", {39.04, 39.04, 39.04, 39.04, 39.04, 39.04}}, -{ "1:4K:16:512", {38.94, 38.94, 38.94, 38.94, 38.94, 38.94}}, -{ "1:4K:16:1K", {38.84, 38.84, 38.84, 38.84, 38.84, 38.84}}, -// FFT3261 -{ "2:256:2:256", {32.05, 32.05, 32.05, 32.05, 32.05, 32.05}}, -{ "2:256:4:256", {31.95, 31.95, 31.95, 31.95, 31.95, 31.95}}, -{ "2:256:8:256", {31.85, 31.85, 31.85, 31.85, 31.85, 31.85}}, -{ "2:512:4:256", {31.85, 31.85, 31.85, 31.85, 31.85, 31.85}}, -{"2:256:16:256", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, -{ "2:512:8:256", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, -{ "2:512:4:512", {31.75, 31.75, 31.75, 31.75, 31.75, 31.75}}, -{ "2:1K:8:256", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, -{"2:512:16:256", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, -{ "2:512:8:512", {31.65, 31.65, 31.65, 31.65, 31.65, 31.65}}, -{ "2:1K:16:256", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, -{ "2:1K:8:512", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, -{"2:512:16:512", {31.55, 31.55, 31.55, 31.55, 31.55, 31.55}}, -{ "2:1K:16:512", {31.45, 31.45, 31.45, 31.45, 31.45, 31.45}}, -{ "2:1K:8:1K", {31.45, 31.45, 31.45, 31.45, 31.45, 31.45}}, -{ "2:1K:16:1K", {31.35, 31.35, 31.35, 31.35, 31.35, 31.35}}, -{ "2:4K:16:512", {31.25, 31.25, 31.25, 31.25, 31.25, 31.25}}, -{ "2:4K:16:1K", {31.15, 31.15, 31.15, 31.15, 31.15, 31.15}}, -// FFT61 -{ "3:256:2:256", {24.20, 24.20, 24.20, 24.20, 24.20, 24.20}}, -{ "3:256:4:256", {24.10, 24.10, 24.10, 24.10, 24.10, 24.10}}, -{ "3:256:8:256", {24.00, 24.00, 24.00, 24.00, 24.00, 24.00}}, -{ "3:512:4:256", {24.00, 24.00, 24.00, 24.00, 24.00, 24.00}}, -{"3:256:16:256", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, -{ "3:512:8:256", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, -{ "3:512:4:512", {23.90, 23.90, 23.90, 23.90, 23.90, 23.90}}, -{ "3:1K:8:256", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, -{"3:512:16:256", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, -{ "3:512:8:512", {23.80, 23.80, 23.80, 23.80, 23.80, 23.80}}, -{ "3:1K:16:256", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, -{ "3:1K:8:512", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, -{"3:512:16:512", {23.70, 23.70, 23.70, 23.70, 23.70, 23.70}}, -{ "3:1K:16:512", {23.60, 23.60, 23.60, 23.60, 23.60, 23.60}}, -{ "3:1K:8:1K", {23.60, 23.60, 23.60, 23.60, 23.60, 23.60}}, -{ "3:1K:16:1K", {23.50, 23.50, 23.50, 23.50, 23.50, 23.50}}, -{ "3:4K:16:512", {23.40, 23.40, 23.40, 23.40, 23.40, 23.40}}, -{ "3:4K:16:1K", {23.30, 23.30, 23.30, 23.30, 23.30, 23.30}}, -// FFT323161 -{ "4:256:2:256", {47.65, 47.65, 47.65, 47.65, 47.65, 47.65}}, -{ "4:256:4:256", {47.55, 47.55, 47.55, 47.55, 47.55, 47.55}}, -{ "4:256:8:256", {47.45, 47.45, 47.45, 47.45, 47.45, 47.45}}, -{ "4:512:4:256", {47.45, 47.45, 47.45, 47.45, 47.45, 47.45}}, -{"4:256:16:256", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, -{ "4:512:8:256", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, -{ "4:512:4:512", {47.35, 47.35, 47.35, 47.35, 47.35, 47.35}}, -{ "4:1K:8:256", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, -{"4:512:16:256", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, -{ "4:512:8:512", {47.25, 47.25, 47.25, 47.25, 47.25, 47.25}}, -{ "4:1K:16:256", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, -{ "4:1K:8:512", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, -{"4:512:16:512", {47.15, 47.15, 47.15, 47.15, 47.15, 47.15}}, -{ "4:1K:16:512", {47.05, 47.05, 47.05, 47.05, 47.05, 47.05}}, -{ "4:1K:8:1K", {47.05, 47.05, 47.05, 47.05, 47.05, 47.05}}, -{ "4:1K:16:1K", {46.95, 46.95, 46.95, 46.95, 46.95, 46.95}}, -{ "4:4K:16:512", {46.85, 46.85, 46.85, 46.85, 46.85, 46.85}}, -{ "4:4K:16:1K", {46.75, 46.75, 46.75, 46.75, 46.75, 46.75}}, -// FFT3231 -{ "50:256:2:256", {16.95, 16.95, 16.95, 16.95, 16.95, 16.95}}, -{ "50:256:4:256", {16.85, 16.85, 16.85, 16.85, 16.85, 16.85}}, -{ "50:256:8:256", {16.75, 16.75, 16.75, 16.75, 16.75, 16.75}}, -{ "50:512:4:256", {16.75, 16.75, 16.75, 16.75, 16.75, 16.75}}, -{"50:256:16:256", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, -{ "50:512:8:256", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, -{ "50:512:4:512", {16.65, 16.65, 16.65, 16.65, 16.65, 16.65}}, -{ "50:1K:8:256", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, -{"50:512:16:256", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, -{ "50:512:8:512", {16.55, 16.55, 16.55, 16.55, 16.55, 16.55}}, -{ "50:1K:16:256", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, -{ "50:1K:8:512", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, -{"50:512:16:512", {16.45, 16.45, 16.45, 16.45, 16.45, 16.45}}, -{ "50:1K:16:512", {16.35, 16.35, 16.35, 16.35, 16.35, 16.35}}, -{ "50:1K:8:1K", {16.35, 16.35, 16.35, 16.35, 16.35, 16.35}}, -{ "50:1K:16:1K", {16.25, 16.25, 16.25, 16.25, 16.25, 16.25}}, -{ "50:4K:16:512", {16.15, 16.15, 16.15, 16.15, 16.15, 16.15}}, -{ "50:4K:16:1K", {16.05, 16.05, 16.05, 16.05, 16.05, 16.05}}, -// FFT6431 -{ "51:256:2:256", {34.26, 34.26, 34.26, 34.26, 34.26, 34.26}}, -{ "51:256:4:256", {34.16, 34.16, 34.16, 34.16, 34.16, 34.16}}, -{ "51:256:8:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, -{ "51:512:4:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, -{"51:256:16:256", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "51:512:8:256", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "51:512:4:512", {33.96, 33.96, 33.96, 33.96, 33.96, 33.96}}, -{ "51:1K:8:256", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{"51:512:16:256", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{ "51:512:8:512", {33.86, 33.86, 33.86, 33.86, 33.86, 33.86}}, -{ "51:1K:16:256", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{ "51:1K:8:512", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{"51:512:16:512", {33.76, 33.76, 33.76, 33.76, 33.76, 33.76}}, -{ "51:1K:16:512", {33.66, 33.66, 33.66, 33.66, 33.66, 33.66}}, -{ "51:1K:8:1K", {33.66, 33.66, 33.66, 33.66, 33.66, 33.66}}, -{ "51:1K:16:1K", {33.56, 33.56, 33.56, 33.56, 33.56, 33.56}}, -{ "51:4K:16:512", {33.46, 33.46, 33.46, 33.46, 33.46, 33.46}}, -{ "51:4K:16:1K", {33.36, 33.36, 33.36, 33.36, 33.36, 33.36}}, - - - +// FFT3161 - Computed by targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "1:256:2:256", {40.54, 40.54, 40.54, 40.54, 40.54, 40.54}}, +{ "1:256:4:256", {40.19, 40.19, 40.19, 40.19, 40.19, 40.19}}, +{ "1:256:8:256", {39.98, 39.98, 39.98, 39.98, 39.98, 39.98}}, +{ "1:512:4:256", {39.98, 39.98, 39.98, 39.98, 39.98, 39.98}}, +{"1:256:16:256", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:512:8:256", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:512:4:512", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:1K:8:256", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{"1:512:16:256", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{ "1:512:8:512", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{ "1:1K:16:256", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{ "1:1K:8:512", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{"1:512:16:512", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{ "1:1K:16:512", {38.97, 38.97, 38.97, 38.97, 38.97, 38.97}}, +{ "1:1K:8:1K", {38.97, 38.97, 38.97, 38.97, 38.97, 38.97}}, +{ "1:1K:16:1K", {38.62, 38.62, 38.62, 38.62, 38.62, 38.62}}, +{ "1:4K:16:512", {38.37, 38.37, 38.37, 38.37, 38.37, 38.37}}, +{ "1:4K:16:1K", {38.12, 38.12, 38.12, 38.12, 38.12, 38.12}}, // Estimated +// FFT3261 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "2:256:2:256", {34.57, 34.57, 34.57, 34.57, 34.57, 34.57}}, +{ "2:256:4:256", {34.24, 34.24, 34.24, 34.24, 34.24, 34.24}}, +{ "2:256:8:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{ "2:512:4:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{"2:256:16:256", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:512:8:256", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:512:4:512", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:1K:8:256", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{"2:512:16:256", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{ "2:512:8:512", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{ "2:1K:16:256", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{ "2:1K:8:512", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{"2:512:16:512", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{ "2:1K:16:512", {28.68, 28.68, 28.68, 28.68, 28.68, 28.68}}, // Very strange. 481421001 has a maxROE of 0.180, 481422001 has a maxROE of 0.5 +{ "2:1K:8:1K", {28.68, 28.68, 28.68, 28.68, 28.68, 28.68}}, +{ "2:1K:16:1K", {25.37, 25.37, 25.37, 25.37, 25.37, 25.37}}, // Also strange. 851422001 ROEmax=0.273, ROEavg=0.003 +{ "2:4K:16:512", {23.27, 23.27, 23.27, 23.27, 23.27, 23.27}}, +{ "2:4K:16:1K", {21.15, 21.15, 21.15, 21.15, 21.15, 21.15}}, // Estimated +// FFT61 - Computed by targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "3:256:2:256", {25.02, 25.02, 25.02, 25.02, 25.02, 25.02}}, +{ "3:256:4:256", {24.72, 24.10, 24.10, 24.10, 24.10, 24.10}}, +{ "3:256:8:256", {24.46, 24.46, 24.46, 24.46, 24.46, 24.46}}, +{ "3:512:4:256", {24.46, 24.46, 24.46, 24.46, 24.46, 24.46}}, +{"3:256:16:256", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:512:8:256", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:512:4:512", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:1K:8:256", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{"3:512:16:256", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{ "3:512:8:512", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{ "3:1K:16:256", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{ "3:1K:8:512", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{"3:512:16:512", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{ "3:1K:16:512", {23.42, 23.42, 23.42, 23.42, 23.42, 23.42}}, +{ "3:1K:8:1K", {23.42, 23.42, 23.42, 23.42, 23.42, 23.42}}, +{ "3:1K:16:1K", {23.13, 23.13, 23.13, 23.13, 23.13, 23.13}}, +{ "3:4K:16:512", {22.92, 22.92, 22.92, 22.92, 22.92, 22.92}}, +{ "3:4K:16:1K", {22.72, 22.72, 22.72, 22.72, 22.72, 22.72}}, // Estimated +// FFT323161 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "4:256:2:256", {50.05, 50.05, 50.05, 50.05, 50.05, 50.05}}, +{ "4:256:4:256", {49.76, 49.76, 49.76, 49.76, 49.76, 49.76}}, +{ "4:256:8:256", {49.59, 49.59, 49.59, 49.59, 49.59, 49.59}}, +{ "4:512:4:256", {49.59, 49.59, 49.59, 49.59, 49.59, 49.59}}, +{"4:256:16:256", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:512:8:256", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:512:4:512", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:1K:8:256", {47.33, 47.33, 47.33, 47.33, 47.33, 47.33}}, +{"4:512:16:256", {47.33, 47.33, 47.33, 47.33, 47.33, 47.33}}, +{ "4:512:8:512", {47.33, 47.33, 47.33, 47.33, 47.33, 47.33}}, +{ "4:1K:16:256", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{ "4:1K:8:512", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{"4:512:16:512", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{ "4:1K:16:512", {44.52, 44.52, 44.52, 44.52, 44.52, 44.52}}, +{ "4:1K:8:1K", {44.52, 44.52, 44.52, 44.52, 44.52, 44.52}}, +{ "4:1K:16:1K", {41.72, 41.72, 41.72, 41.72, 41.72, 41.72}}, // Strange 41.72 has tiny error, 41.75 is 0.5 +{ "4:4K:16:512", {39.50, 39.50, 39.50, 39.50, 39.50, 39.50}}, // Estimated +{ "4:4K:16:1K", {37.50, 37.50, 37.50, 37.50, 37.50, 37.50}}, // Estimated +// FFT3231 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "50:256:2:256", {19.57, 19.57, 19.57, 19.57, 19.57, 19.57}}, +{ "50:256:4:256", {19.23, 19.23, 19.23, 19.23, 19.23, 19.23}}, +{ "50:256:8:256", {19.07, 19.07, 19.07, 19.07, 19.07, 19.07}}, +{ "50:512:4:256", {19.07, 19.07, 19.07, 19.07, 19.07, 19.07}}, +{"50:256:16:256", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:512:8:256", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:512:4:512", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:1K:8:256", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{"50:512:16:256", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{ "50:512:8:512", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{ "50:1K:16:256", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{ "50:1K:8:512", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{"50:512:16:512", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{ "50:1K:16:512", {14.01, 14.01, 14.01, 14.01, 14.01, 14.01}}, +{ "50:1K:8:1K", {14.01, 14.01, 14.01, 14.01, 14.01, 14.01}}, +{ "50:1K:16:1K", {11.21, 11.21, 11.21, 11.21, 11.21, 11.21}}, // Estimated +{ "50:4K:16:512", {9.15, 9.15, 9.15, 9.15, 9.15, 9.15}}, // Estimated +{ "50:4K:16:1K", {7.05, 7.05, 7.05, 7.05, 7.05, 7.05}}, // Estimated +// FFT6431 - Computed with variant 202 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "51:256:2:256", {35.27, 35.27, 35.27, 35.27, 35.27, 35.27}}, +{ "51:256:4:256", {34.98, 34.98, 34.98, 34.98, 34.98, 34.98}}, +{ "51:256:8:256", {34.61, 34.61, 34.61, 34.61, 34.61, 34.61}}, +{ "51:512:4:256", {34.61, 34.61, 34.61, 34.61, 34.61, 34.61}}, +{"51:256:16:256", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:512:8:256", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:512:4:512", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:1K:8:256", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{"51:512:16:256", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{ "51:512:8:512", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{ "51:1K:16:256", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{ "51:1K:8:512", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{"51:512:16:512", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{ "51:1K:16:512", {33.45, 33.45, 33.45, 33.45, 33.45, 33.45}}, +{ "51:1K:8:1K", {33.45, 33.45, 33.45, 33.45, 33.45, 33.45}}, +{ "51:1K:16:1K", {33.23, 33.23, 33.23, 33.23, 33.23, 33.23}}, +{ "51:4K:16:512", {32.75, 32.75, 32.75, 32.75, 32.75, 32.75}}, +{ "51:4K:16:1K", {32.25, 32.25, 32.25, 32.25, 32.25, 32.25}}, // Estimated From fd0a895f4a6115ed763b1a29252c2545048dda10 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 18 Oct 2025 00:11:48 +0000 Subject: [PATCH 069/115] Eliminate fixed definition of SLOPPY_MAXBPW. C code now passes in the actual maxbpw. --- src/Gpu.cpp | 1 + src/cl/carryutil.cl | 36 ++++++++++++------------------------ 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 3ca77b04..8826925c 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -312,6 +312,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< u32 N = fft.shape.size(); defines += toDefine("FFT_VARIANT", fft.variant); + defines += toDefine("MAXBPW", (u32)(fft.maxBpw() * 100.0f)); if (fft.FFT_FP64 | fft.FFT_FP32) { defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 4346cd16..d530d33e 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -153,8 +153,6 @@ void ROUNDOFF_CHECK(double x) { #if FFT_TYPE == FFT64 -#define SLOPPY_MAXBPW 173 // Based on 142.4M expo in 7.5M FFT = 18.36 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -196,8 +194,6 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r #elif FFT_TYPE == FFT32 -#define SLOPPY_MAXBPW 0 // F32 FFTs are not practical - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -237,8 +233,6 @@ i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_r #elif FFT_TYPE == FFT31 -#define SLOPPY_MAXBPW 73 // Based on 140M expo in 16M FFT = 8.34 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { @@ -265,8 +259,6 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { #elif FFT_TYPE == FFT61 -#define SLOPPY_MAXBPW 225 // Based on 198M expo in 8M FFT = 23.6 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { @@ -293,8 +285,6 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { #elif FFT_TYPE == FFT6431 -#define SLOPPY_MAXBPW 327 // Based on 142M expo in 4M FFT = 33.86 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, float* maxROE) { @@ -331,8 +321,6 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, #elif FFT_TYPE == FFT3231 -#define SLOPPY_MAXBPW 154 // Based on 138M expo in 8M FFT = 16.45 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, i32 inCarry, float* maxROE) { @@ -364,8 +352,6 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, #elif FFT_TYPE == FFT3261 -#define SLOPPY_MAXBPW 309 // Based on 134M expo in 4M FFT = 31.95 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { @@ -403,8 +389,6 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, #elif FFT_TYPE == FFT3161 -#define SLOPPY_MAXBPW 383 // Based on 165M expo in 4M FFT = 39.34 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, u32* maxROE) { @@ -446,8 +430,6 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 #elif FFT_TYPE == FFT323161 -#define SLOPPY_MAXBPW 461 // Based on 198M expo in 4M FFT = 47.20 BPW - // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { @@ -640,8 +622,14 @@ Word OVERLOAD carryStepUnsignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { /* Also used on first word in carryFinal when not near max BPW. */ /**********************************************************************/ +// We only allow sloppy results when not near the maximum bits-per-word. For now, this is defined as 1.1 bits below maxbpw. +// No studies have been done on reducing this 1,1 value since this is a rather minor optimization. Since the preprocessor can't +// handle floats, the MAXBPW value passed in is 100 * maxbpw. +#define SLOPPY_MAXBPW (MAXBPW - 1100) +#define ACTUAL_BPW (EXP / (NWORDS / 100)) + Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { -#if EXP > NWORDS / 10 * SLOPPY_MAXBPW +#if ACTUAL_BPW > SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); #else @@ -658,7 +646,7 @@ Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { -#if EXP > NWORDS / 10 * SLOPPY_MAXBPW +#if ACTUAL_BPW > SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); #else @@ -672,7 +660,7 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { i64 xhi = i96_hi64(x) + (xlo_topbit >> 32); *outCarry = xhi >> (nBits - 32); return w; -#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 320 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) +#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 3200 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) i32 w = i96_lo32(x); // lowBits(x, bigwordBits = 32); *outCarry = (i96_hi64(x) + (w < 0)) << (32 - nBits); return w; @@ -685,7 +673,7 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { } Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { -#if EXP > NWORDS / 10 * SLOPPY_MAXBPW +#if ACTUAL_BPW > SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); #else @@ -700,7 +688,7 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { } Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { -#if EXP > NWORDS / 10 * SLOPPY_MAXBPW +#if ACTUAL_BPW > SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); #else @@ -716,7 +704,7 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance). For reasons I don't fully understand the sloppy // case fails if BPW is too low. Probably something to do with a small BPW with sloppy 32-bit values would require CARRY_LONG to work properly. // Not a major concern as end users should avoid small BPW as there is probably a more efficient NTT that could be used. -#elif EXP / NWORDS == 31 || (EXP / NWORDS >= 23 && SLOPPY_MAXBPW >= 320) +#elif EXP / NWORDS == 31 || (EXP / NWORDS >= 23 && SLOPPY_MAXBPW >= 3200) i32 w = x; // lowBits(x, bigwordBits = 32); *outCarry = ((i32)(x >> 32) + (w < 0)) << (32 - nBits); return w; From bc09cc8a4c8568b01f5c5e64b5ac7f8d820870a8 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 18 Oct 2025 22:17:47 +0000 Subject: [PATCH 070/115] Split a long -tune output onto two lines. --- src/tune.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index c7939d60..15fc2bc0 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -431,7 +431,8 @@ void Tune::tune() { } log("\n"); - log("Beginning timing of various options. These settings will be appended to config.txt. Please read config.txt after -tune completes.\n"); + log("Beginning timing of various options. These settings will be appended to config.txt.\n"); + log("Please read config.txt after -tune completes.\n"); log("\n"); u32 variant = (defaultShape == &defaultFFTShape) ? 101 : 202; From 3ff265c9df7c6420a182f1b225d95d371af4c15e Mon Sep 17 00:00:00 2001 From: george Date: Sun, 19 Oct 2025 23:23:36 +0000 Subject: [PATCH 071/115] Return correct type from mad32 --- src/cl/math.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 309b1fe5..619cdb47 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -11,7 +11,7 @@ u32 hi32(u64 x) { return (u32) (x >> 32); } // Multiply and add primitives -u128 mad32(u32 a, u32 b, u64 c) { +u64 mad32(u32 a, u32 b, u64 c) { #if 0 && HAS_PTX // Same speed on TitanV, any gain may be too small to measure u32 reslo, reshi; __asm("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(reslo) : "r"(a), "r"(b), "r"((u32) c)); From 91554a6426202f57efa043df0a710e01092474d5 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 20 Oct 2025 03:19:52 +0000 Subject: [PATCH 072/115] Fuxed bug where MUL3 could overflow an i32 in the (unsupported) M31-only NTT. --- src/cl/carryfused.cl | 5 ----- src/cl/carryinc.cl | 4 ++-- src/cl/carryutil.cl | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index 22d55ad8..da4904b1 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -461,13 +461,8 @@ KERNEL(G_W) carryFused(P(GF31) out, CP(GF31) in, u32 posROE, P(i64) carryShuttle Word2 wu[NW]; -#if MUL3 - P(i64) carryShuttlePtr = (P(i64)) carryShuttle; - i64 carry[NW+1]; -#else P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; CFcarry carry[NW+1]; -#endif #if AMDGPU #define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 00bd1976..3563240c 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -77,7 +77,7 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 u, F2 invWeight, iCARRY inCarry, bool // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. -Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { +Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i32 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); Word a = carryStep(tmp1, &midCarry, b1); @@ -88,7 +88,7 @@ Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i64 in } // Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. -Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { +Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, i32 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index d530d33e..2c35d287 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -234,7 +234,7 @@ i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_r #elif FFT_TYPE == FFT31 // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. -i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { +i64 weightAndCarryOne(Z61 u, u32 invWeight, i32 inCarry, u32* maxROE) { // Apply inverse weight u = shr(u, invWeight); @@ -248,7 +248,7 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { // Mul by 3 and add carry #if MUL3 - value *= 3; + return (i64)value * 3 + inCarry; #endif return value + inCarry; } From a9b49042000a583e515efb8ca093ddc7c60629a7 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 20 Oct 2025 17:01:44 +0000 Subject: [PATCH 073/115] Elimiinated i96_mul. Two adds should be at least as fast as a mul by 3. Renamed i96_add and i96_sub for readability. --- src/cl/carryutil.cl | 21 +++++++++------------ src/cl/math.cl | 15 ++++++--------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 2c35d287..d262ed96 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -305,14 +305,13 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, i64 vhi = n64 >> 33; u64 vlo = ((u64)n64 << 31) | n31; i96 value = make_i96(vhi, vlo); // (n64 << 31) + n31 - i96_sub(&value, make_i96(n64)); // n64 * M31 + n31 + value = sub(value, make_i96(n64)); // n64 * M31 + n31 // Mul by 3 and add carry #if MUL3 - i96_mul(&value, 3); + value = add(value, add(value, value)); #endif - i96_add(&value, make_i96(inCarry)); - return value; + return add(value, make_i96(inCarry)); } /**************************************************************************/ @@ -373,14 +372,13 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i32 vhi = nF2 >> 3; u64 vlo = ((u64)nF2 << 61) | n61; i96 value = make_i96(vhi, vlo); // (nF2 << 61) + n61 - i96_sub(&value, make_i96(nF2)); // nF2 * M61 + n61 + value = sub(value, make_i96(nF2)); // nF2 * M61 + n61 // Mul by 3 and add carry #if MUL3 - i96_mul(&value, 3); + value = add(value, add(value, value)); #endif - i96_add(&value, make_i96(inCarry)); - return value; + return add(value, make_i96(inCarry)); } /**************************************************************************/ @@ -414,14 +412,13 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 i64 vhi = n61 >> 33; u64 vlo = ((u64)n61 << 31) | n31; i96 value = make_i96(vhi, vlo); // (n61 << 31) + n31 - i96_sub(&value, make_i96(n61)); // n61 * M31 + n31 + value = sub(value, make_i96(n61)); // n61 * M31 + n31 // Mul by 3 and add carry #if MUL3 - i96_mul(&value, 3); + value = add(value, add(value, value)); #endif - i96_add(&value, make_i96(inCarry)); - return value; + return add(value, make_i96(inCarry)); } /******************************************************************************/ diff --git a/src/cl/math.cl b/src/cl/math.cl index 619cdb47..61014eac 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -48,9 +48,8 @@ u32 i96_mid32(i96 val) { return val.mid32; } u32 i96_lo32(i96 val) { return val.lo32; } u64 i96_lo64(i96 val) { return ((u64) val.mid32 << 32) | val.lo32; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | val.mid32; } -void i96_add(i96 *val, i96 x) { val->lo32 += x.lo32; val->mid32 += x.mid32; val->hi32 += x.hi32 + (val->mid32 < x.mid32); u32 carry = (val->lo32 < x.lo32); val->mid32 += carry; val->hi32 += (val->mid32 < carry); } -void i96_sub(i96 *val, i96 x) { i96 tmp = *val; val->lo32 -= x.lo32; val->mid32 -= x.mid32; val->hi32 -= x.hi32 + (val->mid32 > tmp.mid32); u32 carry = (val->lo32 > tmp.lo32); tmp = *val; val->mid32 -= carry; val->hi32 -= (val->mid32 > tmp.mid32); } -void i96_mul(i96 *val, u32 x) { u64 t = (u64)val->lo32 * x; val->lo32 = (u32)t; t = (u64)val->mid32 * x + (t >> 32); val->mid32 = (u32)t; val->hi32 = val->hi32 * x + (u32)(t >> 32); } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 + b.lo32; val.mid32 = a.mid32 + b.mid32; val.hi32 = a.hi32 + b.hi32 + (val.mid32 < a.mid32); u32 carry = (val.lo32 < a.lo32); u32 tmp = val.mid32; val.mid32 += carry; val.hi32 += (val.mid32 < tmp); return val; } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 - b.lo32; val.mid32 = a.mid32 - b.mid32; val.hi32 = a.hi32 - b.hi32 - (val.mid32 > a.mid32); u32 carry = (val.lo32 > a.lo32); u32 tmp = val.mid32; val.mid32 -= carry; val.hi32 -= (val.mid32 > tmp); return val; } #elif 0 // A u64 lo32, u32 hi32 implementation. This too would benefit from add with carry instructions. // On nVidia, the clang optimizer kept the hi32 value as 64-bits! @@ -65,9 +64,8 @@ u32 i96_mid32(i96 val) { return hi32(val.lo64); } u32 i96_lo32(i96 val) { return val.lo64; } u64 i96_lo64(i96 val) { return val.lo64; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | i96_mid32(val); } -void i96_add(i96 *val, i96 x) { val->lo64 += x.lo64; val->hi32 += x.hi32 + (val->lo64 < x.lo64); } -void i96_sub(i96 *val, i96 x) { u64 tmp = val->lo64; val->lo64 -= x.lo64; val->hi32 -= x.hi32 + (val->lo64 > tmp); } -void i96_mul(i96 *val, u32 x) { u64 t = i96_lo32(*val) * (u64)x; u32 lo32 = t; t = i96_mid32(*val) * (u64)x + (t >> 32); u32 mid32 = t; u32 hi32 = val->hi32 * x + (t >> 32); *val = make_i96(hi32, mid32, lo32); } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 + b.lo64; val.hi32 = a.hi32 + b.hi32 + (val.lo64 < a.lo64); return val; } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 - b.lo64; val.hi32 = a.hi32 - b.hi32 - (val.lo64 > a.lo64); return val; } #elif 1 // An i128 implementation. This might use more GPU registers. nVidia likes this version. typedef struct { i128 x; } i96; @@ -81,9 +79,8 @@ u32 i96_mid32(i96 val) { return (u64)val.x >> 32; } u32 i96_lo32(i96 val) { return val.x; } u64 i96_lo64(i96 val) { return val.x; } u64 i96_hi64(i96 val) { return (u128)val.x >> 32; } -void i96_add(i96 *val, i96 v) { val->x += v.x; } -void i96_sub(i96 *val, i96 v) { val->x -= v.x; } -void i96_mul(i96 *val, u32 x) { val->x *= x; } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.x = a.x + b.x; return val; } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.x = a.x - b.x; return val; } #endif From 0a108d955bf5c60e1921ba8a6d2e502fcb900470 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 22 Oct 2025 00:40:44 +0000 Subject: [PATCH 074/115] Made a set of routines to support i128 and u128. Needed because Intel compiler does not support __int128 data type. --- src/cl/base.cl | 2 - src/cl/carryutil.cl | 53 +++++++++-------- src/cl/math.cl | 141 +++++++++++++++++++++++++++++--------------- 3 files changed, 120 insertions(+), 76 deletions(-) diff --git a/src/cl/base.cl b/src/cl/base.cl index 2611d281..991c9322 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -171,8 +171,6 @@ typedef int i32; typedef uint u32; typedef long i64; typedef ulong u64; -typedef __int128 i128; -typedef unsigned __int128 u128; // Data types for data stored in FFTs and NTTs during the transform typedef double T; // For historical reasons, classic FFTs using doubles call their data T and T2. diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index d262ed96..b305f499 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -305,13 +305,13 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, i64 vhi = n64 >> 33; u64 vlo = ((u64)n64 << 31) | n31; i96 value = make_i96(vhi, vlo); // (n64 << 31) + n31 - value = sub(value, make_i96(n64)); // n64 * M31 + n31 + value = sub(value, n64); // n64 * M31 + n31 // Mul by 3 and add carry #if MUL3 value = add(value, add(value, value)); #endif - return add(value, make_i96(inCarry)); + return add(value, inCarry); } /**************************************************************************/ @@ -372,13 +372,13 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i32 vhi = nF2 >> 3; u64 vlo = ((u64)nF2 << 61) | n61; i96 value = make_i96(vhi, vlo); // (nF2 << 61) + n61 - value = sub(value, make_i96(nF2)); // nF2 * M61 + n61 + value = sub(value, nF2); // nF2 * M61 + n61 // Mul by 3 and add carry #if MUL3 value = add(value, add(value, value)); #endif - return add(value, make_i96(inCarry)); + return add(value, inCarry); } /**************************************************************************/ @@ -412,13 +412,13 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 i64 vhi = n61 >> 33; u64 vlo = ((u64)n61 << 31) | n31; i96 value = make_i96(vhi, vlo); // (n61 << 31) + n31 - value = sub(value, make_i96(n61)); // n61 * M31 + n31 + value = sub(value, n61); // n61 * M31 + n31 // Mul by 3 and add carry #if MUL3 value = add(value, add(value, value)); #endif - return add(value, make_i96(inCarry)); + return add(value, inCarry); } /******************************************************************************/ @@ -439,16 +439,20 @@ i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_ u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) u64 n61 = get_Z61(u61); - i128 n3161 = (((i128) n61 << 31) | n31) - n61; // n61 * M31 + n31 + + i128 n3161 = make_i128(n61 >> 33, (n61 << 31) | n31); // n61 << 31 + n31 + n3161 = sub(n3161, n61); // n61 * M31 + n31 // The final result must be n3161 mod M31*M61. Use FP32 data to calculate this value. - float n3161f = (float)((u32)(n61 >> 32)) * -9223372036854775808.0f; // Conversion from i128 to float might be slow, this might be faster + float n3161f = (float)((u32)(n61 >> 32)) * -9223372036854775808.0f; // Converting n3161 from i128 to float might be slow, this might be faster uF2 = fma(uF2, F2_invWeight, n3161f); // This should be close to a multiple of M31*M61 float uF2int = fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL); // Divide by M31*M61 and round to int i32 nF2 = RNDVALfloatToInt(uF2int); - i64 nF2m31 = ((i64) nF2 << 31) - nF2; // nF2 * M31 - i128 v = ((i128) nF2m31 << 61) - nF2m31 + n3161; // nF2m31 * M61 + n3161 + i64 nF2m31 = ((i64)nF2 << 31) - nF2; // nF2 * M31 + i128 v = make_i128(nF2m31 >> 3, (u64)nF2m31 << 61); // nF2m31 << 61 + v = sub(v, nF2m31); // nF2m31 * M61 + v = add(v, n3161); // nF2m31 * M61 + n3161 // Optionally calculate roundoff error float roundoff = fabs(fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL - uF2int)); @@ -456,9 +460,9 @@ i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_ // Mul by 3 and add carry #if MUL3 - v = v * 3; + v = add(v, add(v, v)); #endif - v = v + inCarry; + v = add(v, inCarry); return v; } @@ -473,8 +477,8 @@ error - missing weightAndCarryOne implementation Word OVERLOAD carryStep(i128 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); - i64 w = lowBits((i64)x, nBits); - *outCarry = (u64)(x >> nBits) + (w < 0); + i64 w = lowBits(i128_lo64(x), nBits); + *outCarry = i128_shrlo64(x, nBits) + (w < 0); return w; } @@ -568,11 +572,9 @@ Word OVERLOAD carryStepUnsignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); // Return a Word using the big word size. Big word size is a constant which allows for more optimization. - u64 w = ulowFixedBits((u64)x, bigwordBits); - const i128 topbitmask = ~((i128)1 << (bigwordBits - 1)); -//GW Can we use unsigned shift (knowing the sign won't be lost due to truncating the result) -- this is really a 64-bit extract (or two 32-bit extrats) -- use elsewhere?) - *outCarry = (x & topbitmask) >> nBits; -//GW use this style else where, check for more fixed low bits + u64 w = ulowFixedBits(i128_lo64(x), bigwordBits); + x = i128_masklo64(x, ~((u64)1 << (bigwordBits - 1))); + *outCarry = i128_shrlo64(x, nBits); return w; } @@ -582,7 +584,7 @@ Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { // Return a Word using the big word size. Big word size is a constant which allows for more optimization. #if EXP / NWORDS >= 32 // nBits is 32 or more - i64 xhi = i96_hi64(x) & ~((1ULL << (bigwordBits - 32)) - 1); + i64 xhi = i96_hi64(x) & ~(((u64)1 << (bigwordBits - 32)) - 1); *outCarry = xhi >> (nBits - 32); return ulowFixedBits(i96_lo64(x), bigwordBits); #elif EXP / NWORDS == 31 // nBits = 31 or 32 @@ -633,15 +635,14 @@ Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { // Return a Word using the big word size. Big word size is a constant which allows for more optimization. const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); - u64 xlo = (u64)x; - u64 xlo_topbit = xlo & (1ULL << (bigwordBits - 1)); + u64 xlo = i128_lo64(x); + u64 xlo_topbit = xlo & ((u64)1 << (bigwordBits - 1)); i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; - *outCarry = (x + xlo_topbit) >> nBits; + *outCarry = i128_shrlo64(add(x, xlo_topbit), nBits); return w; #endif } - Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { #if ACTUAL_BPW > SLOPPY_MAXBPW return carryStep(x, outCarry, isBigWord); @@ -652,7 +653,7 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 32 // nBits is 32 or more u64 xlo = i96_lo64(x); - u64 xlo_topbit = xlo & (1ULL << (bigwordBits - 1)); + u64 xlo_topbit = xlo & ((u64)1 << (bigwordBits - 1)); i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; i64 xhi = i96_hi64(x) + (xlo_topbit >> 32); *outCarry = xhi >> (nBits - 32); @@ -693,7 +694,7 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 32 // nBits is 32 or more - u64 x_topbit = x & (1ULL << (bigwordBits - 1)); + u64 x_topbit = x & ((u64)1 << (bigwordBits - 1)); i64 w = ulowFixedBits(x, bigwordBits - 1) - x_topbit; i32 xhi = (i32)(x >> 32) + (i32)(x_topbit >> 32); *outCarry = xhi >> (nBits - 32); diff --git a/src/cl/math.cl b/src/cl/math.cl index 61014eac..8146c441 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -9,36 +9,11 @@ u32 lo32(u64 x) { return (u32) x; } u32 hi32(u64 x) { return (u32) (x >> 32); } -// Multiply and add primitives - -u64 mad32(u32 a, u32 b, u64 c) { -#if 0 && HAS_PTX // Same speed on TitanV, any gain may be too small to measure - u32 reslo, reshi; - __asm("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(reslo) : "r"(a), "r"(b), "r"((u32) c)); - __asm("madc.hi.u32 %0, %1, %2, %3;" : "=r"(reshi) : "r"(a), "r"(b), "r"((u32) (c >> 32))); - return ((u64)reshi << 32) | reslo; -#else - return (u64) a * (u64) b + c; -#endif -} - -u128 mad64(u64 a, u64 b, u128 c) { -#if 0 && HAS_PTX // Slower on TitanV, don't understand why - u64 reslo, reshi; - __asm("mad.lo.cc.u64 %0, %1, %2, %3;" : "=l"(reslo) : "l"(a), "l"(b), "l"((u64) c)); - __asm("madc.hi.u64 %0, %1, %2, %3;" : "=l"(reshi) : "l"(a), "l"(b), "l"((u64) (c >> 64))); - return ((u128)reshi << 64) | reslo; -#else - return (u128) a * (u128) b + c; -#endif -} - // A primitive partial implementation of an i96 integer type #if 0 // An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. // This version might be best on AMD and Intel if we can generate add-with-carry instructions. typedef struct { u32 lo32; u32 mid32; u32 hi32; } i96; -i96 OVERLOAD make_i96(i128 v) { i96 val; val.hi32 = (u128)v >> 64, val.mid32 = (u64)v >> 32, val.lo32 = v; return val; } i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.mid32 = v >> 32, val.lo32 = v; return val; } i96 OVERLOAD make_i96(i32 v) { i96 val; val.hi32 = v >> 31, val.mid32 = v >> 31, val.lo32 = v; return val; } i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.mid32 = lo >> 32, val.lo32 = lo; return val; } @@ -49,12 +24,31 @@ u32 i96_lo32(i96 val) { return val.lo32; } u64 i96_lo64(i96 val) { return ((u64) val.mid32 << 32) | val.lo32; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | val.mid32; } i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 + b.lo32; val.mid32 = a.mid32 + b.mid32; val.hi32 = a.hi32 + b.hi32 + (val.mid32 < a.mid32); u32 carry = (val.lo32 < a.lo32); u32 tmp = val.mid32; val.mid32 += carry; val.hi32 += (val.mid32 < tmp); return val; } +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 - b.lo32; val.mid32 = a.mid32 - b.mid32; val.hi32 = a.hi32 - b.hi32 - (val.mid32 > a.mid32); u32 carry = (val.lo32 > a.lo32); u32 tmp = val.mid32; val.mid32 -= carry; val.hi32 -= (val.mid32 > tmp); return val; } -#elif 0 +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#elif defined(__SIZEOF_INT128__) +// An i128 implementation. This might use more GPU registers. nVidia likes this version. +typedef struct { __int128 x; } i96; +i96 OVERLOAD make_i96(i64 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(i32 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.x = ((unsigned __int128)hi << 64) + lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { return make_i96((i64)hi, lo); } +u32 i96_hi32(i96 val) { return (unsigned __int128)val.x >> 64; } +u32 i96_mid32(i96 val) { return (u64)val.x >> 32; } +u32 i96_lo32(i96 val) { return val.x; } +u64 i96_lo64(i96 val) { return val.x; } +u64 i96_hi64(i96 val) { return (unsigned __int128)val.x >> 32; } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.x = a.x + b.x; return val; } +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.x = a.x - b.x; return val; } +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#elif 1 // A u64 lo32, u32 hi32 implementation. This too would benefit from add with carry instructions. // On nVidia, the clang optimizer kept the hi32 value as 64-bits! typedef struct { u64 lo64; u32 hi32; } i96; -i96 OVERLOAD make_i96(i128 v) { i96 val; val.hi32 = (u128)v >> 64, val.lo64 = v; return val; } i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.lo64 = v; return val; } i96 OVERLOAD make_i96(i32 v) { return make_i96((i64)v); } i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.lo64 = lo; return val; } @@ -65,24 +59,75 @@ u32 i96_lo32(i96 val) { return val.lo64; } u64 i96_lo64(i96 val) { return val.lo64; } u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | i96_mid32(val); } i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 + b.lo64; val.hi32 = a.hi32 + b.hi32 + (val.lo64 < a.lo64); return val; } +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 - b.lo64; val.hi32 = a.hi32 - b.hi32 - (val.lo64 > a.lo64); return val; } -#elif 1 -// An i128 implementation. This might use more GPU registers. nVidia likes this version. -typedef struct { i128 x; } i96; -i96 OVERLOAD make_i96(i128 v) { i96 val; val.x = v; return val; } -i96 OVERLOAD make_i96(i64 v) { return make_i96((i128)v); } -i96 OVERLOAD make_i96(i32 v) { return make_i96((i128)v); } -i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.x = ((u128)hi << 64) + lo; return val; } -i96 OVERLOAD make_i96(i32 hi, u64 lo) { return make_i96((i64)hi, lo); } -u32 i96_hi32(i96 val) { return (u128)val.x >> 64; } -u32 i96_mid32(i96 val) { return (u64)val.x >> 32; } -u32 i96_lo32(i96 val) { return val.x; } -u64 i96_lo64(i96 val) { return val.x; } -u64 i96_hi64(i96 val) { return (u128)val.x >> 32; } -i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.x = a.x + b.x; return val; } -i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.x = a.x - b.x; return val; } +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#endif + +// A primitive partial implementation of an i128 and u128 integer type +#if defined(__SIZEOF_INT128__) +typedef struct { __int128 x; } i128; +typedef struct { unsigned __int128 x; } u128; +i128 OVERLOAD make_i128(i64 hi, u64 lo) { i128 val; val.x = ((__int128)hi << 64) | lo; return val; } +i128 OVERLOAD make_i128(u64 hi, u64 lo) { i128 val; val.x = ((__int128)hi << 64) | lo; return val; } +u64 i128_lo64(i128 val) { return val.x; } +u64 i128_shrlo64(i128 val, u32 bits) { return val.x >> bits; } +i128 OVERLOAD i128_masklo64(i128 a, u64 m) { i128 val; val.x = a.x & (((__int128)0xFFFFFFFFFFFFFFFFULL << 64) | m); return val; } +i128 OVERLOAD add(i128 a, i128 b) { i128 val; val.x = a.x + b.x; return val; } +i128 OVERLOAD add(i128 a, u64 b) { i128 val; val.x = a.x + (__int128)b; return val; } +i128 OVERLOAD add(i128 a, i64 b) { i128 val; val.x = a.x + (__int128)b; return val; } +i128 OVERLOAD sub(i128 a, u64 b) { i128 val; val.x = a.x - (__int128)b; return val; } +i128 OVERLOAD sub(i128 a, i64 b) { i128 val; val.x = a.x - (__int128)b; return val; } +u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.x = ((unsigned __int128)hi << 64) + lo; return val; } +u64 u128_lo64(u128 val) { return val.x; } +u64 u128_hi64(u128 val) { return val.x >> 64; } +u128 mul64(u64 a, u64 b) { u128 val; val.x = (unsigned __int128)a * (unsigned __int128)b; return val; } +u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.x = a.x + b.x; return val; } +#else // UNTESTED! The mul64 macro causes clang to hang! +typedef struct { i64 hi64; u64 lo64; } i128; +typedef struct { u64 hi64; u64 lo64; } u128; +i128 OVERLOAD make_i128(i64 hi, u64 lo) { i128 val; val.hi64 = hi; val.lo64 = lo; return val; } +i128 OVERLOAD make_i128(u64 hi, u64 lo) { i128 val; val.hi64 = hi; val.lo64 = lo; return val; } +u64 i128_lo64(i128 val) { return val.lo64; } +u64 i128_shrlo64(i128 val, u32 bits) { return (val.hi64 << (64 - bits)) | (val.lo64 >> bits); } +i128 OVERLOAD i128_masklo64(i128 a, u64 m) { i128 val; val.lo64 = a.lo64 & m; val.hi64 = a.hi64; return val; } +i128 OVERLOAD add(i128 a, i128 b) { i128 val; val.lo64 = a.lo64 + b.lo64; val.hi64 = a.hi64 + b.hi64 + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD add(i128 a, u64 b) { i128 val; val.lo64 = a.lo64 + b; val.hi64 = a.hi64 + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD add(i128 a, i64 b) { i128 val; val.lo64 = a.lo64 + b; val.hi64 = a.hi64 + (b >> 63) + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD sub(i128 a, u64 b) { i128 val; val.lo64 = a.lo64 - b; val.hi64 = a.hi64 - (val.lo64 > a.lo64); return val; } +i128 OVERLOAD sub(i128 a, i64 b) { i128 val; val.lo64 = a.lo64 - (u64)b; val.hi64 = a.hi64 - (b >> 63) - (val.lo64 > a.lo64); return val; } +u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.hi64 = hi; val.lo64 = lo; return val; } +u64 u128_lo64(u128 val) { return val.lo64; } +u64 u128_hi64(u128 val) { return val.hi64; } +u128 mul64(u64 a, u64 b) { u128 val; val.lo64 = a * b; val.hi64 = mul_hi(a, b); return val; } +u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.lo64 = a.lo64 + b.lo64; val.hi64 = a.hi64 + b.hi64 + (val.lo64 < a.lo64); return val; } #endif +// Multiply and add primitives + +u64 mad32(u32 a, u32 b, u64 c) { +#if 0 && HAS_PTX // Same speed on TitanV, any gain may be too small to measure + u32 reslo, reshi; + __asm("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(reslo) : "r"(a), "r"(b), "r"((u32) c)); + __asm("madc.hi.u32 %0, %1, %2, %3;" : "=r"(reshi) : "r"(a), "r"(b), "r"((u32) (c >> 32))); + return ((u64)reshi << 32) | reslo; +#else + return (u64) a * (u64) b + c; +#endif +} + +u128 mad64(u64 a, u64 b, u128 c) { +#if 0 && HAS_PTX // Slower on TitanV, don't understand why + u64 reslo, reshi; + __asm("mad.lo.cc.u64 %0, %1, %2, %3;" : "=l"(reslo) : "l"(a), "l"(b), "l"((u64) c)); + __asm("madc.hi.u64 %0, %1, %2, %3;" : "=l"(reshi) : "l"(a), "l"(b), "l"((u64) (c >> 64))); + return make_u128(reshi, reslo); +#else + return add(mul64(a, b), c); +#endif +} + // The X2 family of macros and SWAP are #defines because OpenCL does not allow pass by reference. // With NTT support added, we need to turn these macros into overloaded routines. @@ -622,8 +667,8 @@ Z61 OVERLOAD shr(Z61 a, u32 k) { return ((a >> k) + (a << (61 - k))) & M61; } GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } ulong2 wideMul(u64 ab, u64 cd) { - u128 r = (u128) ab * (u128) cd; - return U2((u64) r, (u64) (r >> 64)); + u128 r = mul64(ab, cd); + return U2(u128_lo64(r), u128_hi64(r)); } Z61 OVERLOAD mul(Z61 a, Z61 b) { @@ -752,8 +797,8 @@ Z61 OVERLOAD shl(Z61 a, u32 k) { return shr(a, 61 - k); } // Return rang GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } ulong2 wideMul(u64 ab, u64 cd) { - u128 r = (u128) ab * (u128) cd; - return U2((u64) r, (u64) (r >> 64)); + u128 r = mul64(ab, cd); + return U2(u128_lo64(r), u128_hi64(r)); } // Returns a * b not modded by M61. Max value of result depends on the m61_counts of the inputs. @@ -808,7 +853,7 @@ GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 3-epsilon extra bits in #else Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b_m61_count) { u128 ab = mad64(a, b, c); // Max c value assumed to be M61^2+epsilon - u64 lo = ab, hi = ab >> 64; + u64 lo = u128_lo64(ab), hi = u128_hi64(ab); u64 lo61 = lo & M61; // Max value is M61 if ((a_m61_count - 1) * (b_m61_count - 1) + 1 <= 6) { hi = (hi << 3) + (lo >> 61); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 1) * M61 + epsilon @@ -819,7 +864,7 @@ Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b } } GF61 OVERLOAD cmul(GF61 a, GF61 b) { - u128 k1 = (u128) b.x * (u128) (a.x + a.y); // max value is M61^2+epsilon + u128 k1 = mul64(b.x, a.x + a.y); // max value is M61^2+epsilon Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 3); // max value is 4*M61+epsilon Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 5*M61+epsilon return U2(modM61(k1k3), modM61(k1k2)); From 95ea1283620d4f6655ec043a676036e69bc3d9fa Mon Sep 17 00:00:00 2001 From: george Date: Wed, 22 Oct 2025 00:59:26 +0000 Subject: [PATCH 075/115] Detect nVidia GPUs. Set HAS_PTX. --- src/Gpu.cpp | 1 + src/cl/base.cl | 6 +++++- src/clwrap.cpp | 6 ++++++ src/clwrap.h | 1 + 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 8826925c..2d5518fd 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -304,6 +304,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< }); if (isAmdGpu(id)) { defines += toDefine("AMDGPU", 1); } + if (isNvidiaGpu(id)) { defines += toDefine("NVIDIAGPU", 1); } if ((fft.carry == CARRY_AUTO && fft.shape.needsLargeCarry(E)) || (fft.carry == CARRY_64)) { if (doLog) { log("Using CARRY64\n"); } diff --git a/src/cl/base.cl b/src/cl/base.cl index 991c9322..0ebf0fc3 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -19,6 +19,7 @@ CARRY_LEN NW NH AMDGPU : if this is an AMD GPU +NVIDIAGPU : if this is an nVidia GPU HAS_ASM : set if we believe __asm() can be used for AMD GCN HAS_PTX : set if we believe __asm() can be used for nVidia PTX @@ -63,9 +64,12 @@ G_H "group height" == SMALL_HEIGHT / NH #elif AMDGPU #define HAS_ASM 1 #define HAS_PTX 0 -#else // Assume it is as nVidia GPU (can C code detect nVidia like it does for AMD?) +#elif NVIDIAGPU #define HAS_ASM 0 #define HAS_PTX 1 +#else +#define HAS_ASM 0 +#define HAS_PTX 0 #endif // Default is not adding -2 to results for LL diff --git a/src/clwrap.cpp b/src/clwrap.cpp index 29cdab5d..46e078f0 100644 --- a/src/clwrap.cpp +++ b/src/clwrap.cpp @@ -172,6 +172,12 @@ bool isAmdGpu(cl_device_id id) { return pcieId == 0x1002; } +bool isNvidiaGpu(cl_device_id id) { + u32 pcieId = 0; + GET_INFO(id, CL_DEVICE_VENDOR_ID, pcieId); + return pcieId == 0x10DE; +} + /* static string getFreq(cl_device_id device) { unsigned computeUnits, frequency; diff --git a/src/clwrap.h b/src/clwrap.h index 919f3a4c..e67011b4 100644 --- a/src/clwrap.h +++ b/src/clwrap.h @@ -62,6 +62,7 @@ float getGpuRamGB(cl_device_id id); u64 getFreeMem(cl_device_id id); bool hasFreeMemInfo(cl_device_id id); bool isAmdGpu(cl_device_id id); +bool isNvidiaGpu(cl_device_id id); string getDriverVersion(cl_device_id id); string getDriverVersionByPos(int pos); From bcd3c35dbd8f06962a58f5917b1abfb335d3d79b Mon Sep 17 00:00:00 2001 From: george Date: Wed, 22 Oct 2025 02:26:10 +0000 Subject: [PATCH 076/115] Added some changes for an MSYS build --- src/File.h | 4 ++++ src/Task.cpp | 2 +- src/main.cpp | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/File.h b/src/File.h index 6049246d..8419f9e1 100644 --- a/src/File.h +++ b/src/File.h @@ -60,6 +60,10 @@ class File { _commit(fileno(f)); #elif defined(__APPLE__) fcntl(fileno(f), F_FULLFSYNC, 0); +#elif defined(__MSYS__) +#define fileno(__F) ((__F)->_file) + fsync(fileno(f)); +#undef fileno #else fdatasync(fileno(f)); #endif diff --git a/src/Task.cpp b/src/Task.cpp index 540010fe..676359ec 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -40,7 +40,7 @@ constexpr int platform() { const constexpr bool IS_32BIT = (sizeof(void*) == 4); -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) || defined(__MINGW32__) || defined(__MINGW64__) +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) || defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) return IS_32BIT ? WIN_32 : WIN_64; #elif __APPLE__ diff --git a/src/main.cpp b/src/main.cpp index 3ded0a9c..a0d63278 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -40,13 +40,13 @@ void gpuWorker(GpuCommon shared, Queue *q, i32 instance) { } -#if defined(__MINGW32__) || defined(__MINGW64__) // for Windows +#if defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) // for Windows extern int putenv(const char *); #endif int main(int argc, char **argv) { -#if defined(__MINGW32__) || defined(__MINGW64__) +#if defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) putenv("ROC_SIGNAL_POOL_SIZE=32"); #else // Required to work around a ROCm bug when using multiple queues From da46795e1cd4a64462b85a603f4e3f93b6ce78aa Mon Sep 17 00:00:00 2001 From: george Date: Wed, 22 Oct 2025 02:37:37 +0000 Subject: [PATCH 077/115] Another fix to get prpll linking under MSYS2 --- src/main.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index a0d63278..b62f1b32 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -41,12 +41,14 @@ void gpuWorker(GpuCommon shared, Queue *q, i32 instance) { #if defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) // for Windows -extern int putenv(const char *); +extern int putenv(char *); #endif int main(int argc, char **argv) { -#if defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) +#if defined(__MSYS__) + // I was unable to get putenv to link in MSYS2 +#elif defined(__MINGW32__) || defined(__MINGW64__) putenv("ROC_SIGNAL_POOL_SIZE=32"); #else // Required to work around a ROCm bug when using multiple queues From 5057c13b6dabdf58c84e93152e3be155844fcf7a Mon Sep 17 00:00:00 2001 From: george Date: Wed, 22 Oct 2025 21:28:35 +0000 Subject: [PATCH 078/115] Wrote a PTX version of mad64 that is faster (on TitanV) than both old PTX version and NO_ASM version. --- src/cl/math.cl | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 8146c441..75dcbc0d 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -79,12 +79,12 @@ i128 OVERLOAD add(i128 a, u64 b) { i128 val; val.x = a.x + (__int128)b; return v i128 OVERLOAD add(i128 a, i64 b) { i128 val; val.x = a.x + (__int128)b; return val; } i128 OVERLOAD sub(i128 a, u64 b) { i128 val; val.x = a.x - (__int128)b; return val; } i128 OVERLOAD sub(i128 a, i64 b) { i128 val; val.x = a.x - (__int128)b; return val; } -u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.x = ((unsigned __int128)hi << 64) + lo; return val; } +u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.x = ((unsigned __int128)hi << 64) | lo; return val; } u64 u128_lo64(u128 val) { return val.x; } u64 u128_hi64(u128 val) { return val.x >> 64; } u128 mul64(u64 a, u64 b) { u128 val; val.x = (unsigned __int128)a * (unsigned __int128)b; return val; } u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.x = a.x + b.x; return val; } -#else // UNTESTED! The mul64 macro causes clang to hang! +#else // UNTESTED! The mul64 macro causes clang to hang! typedef struct { i64 hi64; u64 lo64; } i128; typedef struct { u64 hi64; u64 lo64; } u128; i128 OVERLOAD make_i128(i64 hi, u64 lo) { i128 val; val.hi64 = hi; val.lo64 = lo; return val; } @@ -107,22 +107,41 @@ u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.lo64 = a.lo64 + b.lo64; val.hi // Multiply and add primitives u64 mad32(u32 a, u32 b, u64 c) { -#if 0 && HAS_PTX // Same speed on TitanV, any gain may be too small to measure +#if HAS_PTX // Same speed on TitanV, any gain may be too small to measure u32 reslo, reshi; - __asm("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(reslo) : "r"(a), "r"(b), "r"((u32) c)); - __asm("madc.hi.u32 %0, %1, %2, %3;" : "=r"(reshi) : "r"(a), "r"(b), "r"((u32) (c >> 32))); + __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" + "madc.hi.u32 %1, %2, %3, %5;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"((u32)c), "r"((u32)(c >> 32))); return ((u64)reshi << 32) | reslo; #else - return (u64) a * (u64) b + c; + return (u64)a * (u64)b + c; #endif } u128 mad64(u64 a, u64 b, u128 c) { -#if 0 && HAS_PTX // Slower on TitanV, don't understand why +#if 0 && HAS_PTX // Slower on TitanV and mobile 4070, don't understand why u64 reslo, reshi; - __asm("mad.lo.cc.u64 %0, %1, %2, %3;" : "=l"(reslo) : "l"(a), "l"(b), "l"((u64) c)); - __asm("madc.hi.u64 %0, %1, %2, %3;" : "=l"(reshi) : "l"(a), "l"(b), "l"((u64) (c >> 64))); + __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" + "madc.hi.u64 %1, %2, %3, %5;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c)), "l"(u128_hi64(c))); return make_u128(reshi, reslo); +#elif HAS_PTX // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. + uint2 a2 = as_uint2(a); + uint2 b2 = as_uint2(b); + uint2 clo2 = as_uint2(u128_lo64(c)); + uint2 chi2 = as_uint2(u128_hi64(c)); + uint2 rlo2, rhi2; + __asm("mad.lo.cc.u32 %0, %4, %6, %8;\n\t" + "madc.hi.cc.u32 %1, %4, %6, %9;\n\t" + "madc.lo.cc.u32 %2, %5, %7, %10;\n\t" + "madc.hi.u32 %3, %5, %7, %11;\n\t" + "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" + "madc.hi.cc.u32 %2, %5, %6, %2;\n\t" + "addc.u32 %3, %3, 0;\n\t" + "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" + "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" + "addc.u32 %3, %3, 0;" + : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "+r"(rhi2.y) + : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(clo2.x), "r"(clo2.y), "r"(chi2.x), "r"(chi2.y)); + return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); #else return add(mul64(a, b), c); #endif From 4566c262f2a97e62cff69e5599db86001a4674db Mon Sep 17 00:00:00 2001 From: george Date: Fri, 24 Oct 2025 03:16:11 +0000 Subject: [PATCH 079/115] Implemented i96 as three 32-bit quantities with PTX asm. Reworked carryutil routines by looking at generatd PTX code. Added PTX asm in some cases. Pass boolean to carryOnePair to indicate there i san initial carry. With asm code the compiler cannot optimze away an add of constant zero. --- src/cl/carry.cl | 8 +-- src/cl/carryfused.cl | 8 +-- src/cl/carryinc.cl | 48 +++++++-------- src/cl/carryutil.cl | 144 ++++++++++++++++++++++++------------------- src/cl/math.cl | 54 ++++++++++++---- 5 files changed, 153 insertions(+), 109 deletions(-) diff --git a/src/cl/carry.cl b/src/cl/carry.cl index c1ebbf3a..8741c334 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -300,7 +300,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big // Compute result out[p] = weightAndCarryPair(SWAP_XY(in[p]), SWAP_XY(in31[p]), w1, w2, weight_shift0, weight_shift1, - carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair combo_counter += combo_step; @@ -458,7 +458,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big // Compute result out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in61[p]), w1, w2, weight_shift0, weight_shift1, - carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair combo_counter += combo_step; @@ -546,7 +546,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u // Compute result out[p] = weightAndCarryPair(SWAP_XY(in31[p]), SWAP_XY(in61[p]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, - carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_step; @@ -643,7 +643,7 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big // Compute result out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in31[p]), SWAP_XY(in61[p]), w1, w2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, - carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_step; diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index da4904b1..dba02e76 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -973,7 +973,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), SWAP_XY(u31[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair combo_counter += combo_bigstep; @@ -1465,7 +1465,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair combo_counter += combo_bigstep; @@ -1722,7 +1722,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( wu[i] = weightAndCarryPairSloppy(SWAP_XY(u31[i]), SWAP_XY(u61[i]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; @@ -2000,7 +2000,7 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u31[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 3563240c..459476da 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -138,11 +138,11 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF61 u, u32 invWeight1, u32 invWeight2, // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. Word2 OVERLOAD weightAndCarryPair(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; - i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStep(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, true, midCarry, maxROE); Word b = carryStep(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -150,11 +150,11 @@ Word2 OVERLOAD weightAndCarryPair(T2 u, GF31 u31, T invWeight1, T invWeight2, u3 // Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; - i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, true, midCarry, maxROE); Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -202,11 +202,11 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWei // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; - i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStep(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStep(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -214,11 +214,11 @@ Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF61 u61, F invWeight1, F invWeight2, // Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { i64 midCarry; - i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -234,11 +234,11 @@ Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWei // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. Word2 OVERLOAD weightAndCarryPair(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; - i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStep(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStep(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -246,11 +246,11 @@ Word2 OVERLOAD weightAndCarryPair(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m3 // Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, - i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { iCARRY midCarry; - i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); - i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -265,11 +265,11 @@ Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, - u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + u32 m61_invWeight1, u32 m61_invWeight2, bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; - i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStep(tmp1, &midCarry, b1); - i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStep(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); @@ -277,11 +277,11 @@ Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F in // Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, - u32 m61_invWeight1, u32 m61_invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + u32 m61_invWeight1, u32 m61_invWeight2, bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; - i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, inCarry, maxROE); + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); - i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, midCarry, maxROE); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); Word b = carryStepSignedSloppy(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index b305f499..9f784a35 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -24,6 +24,8 @@ typedef i32 CarryABM; // Return unsigned low bits (number of bits must be between 1 and 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) u32 OVERLOAD ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } +#elif HAS_PTX +u32 OVERLOAD ulowBits(i32 u, u32 bits) { u32 res; __asm("szext.clamp.u32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else u32 OVERLOAD ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } #endif @@ -33,11 +35,7 @@ u64 OVERLOAD ulowBits(i64 u, u32 bits) { return (((u64) u << (64 - bits)) >> (64 u64 OVERLOAD ulowBits(u64 u, u32 bits) { return ulowBits((i64) u, bits); } // Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 32) -#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) -u32 OVERLOAD ulowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_ubfe(u, 0, bits); } -#else u32 OVERLOAD ulowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return u & ((1 << bits) - 1); } -#endif u32 OVERLOAD ulowFixedBits(u32 u, const u32 bits) { return ulowFixedBits((i32) u, bits); } // Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 64) u64 OVERLOAD ulowFixedBits(i64 u, const u32 bits) { return u & ((1LL << bits) - 1); } @@ -46,32 +44,45 @@ u64 OVERLOAD ulowFixedBits(u64 u, const u32 bits) { return ulowFixedBits((i64) u // Return signed low bits (number of bits must be between 1 and 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) i32 OVERLOAD lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } +#elif HAS_PTX +i32 OVERLOAD lowBits(i32 u, u32 bits) { i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else i32 OVERLOAD lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } #endif -i32 OVERLOAD lowBits(u32 u, u32 bits) { return lowBits((i32) u, bits); } +i32 OVERLOAD lowBits(u32 u, u32 bits) { return lowBits((i32)u, bits); } // Return signed low bits (number of bits must be between 1 and 63) i64 OVERLOAD lowBits(i64 u, u32 bits) { return ((u << (64 - bits)) >> (64 - bits)); } -i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64) u, bits); } +i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64)u, bits); } + +// Return signed low bits (number of bits must be between 1 and 32) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) +i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } +#elif HAS_PTX +i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } +#else +i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits((u64)u, bits); } +#endif +i32 OVERLOAD lowBitsSafe32(u32 u, u32 bits) { return lowBitsSafe32((i32)u, bits); } // Return signed low bits where number of bits is known at compile time (number of bits can be 0 to 32) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_sbfe(u, 0, bits); } +#elif HAS_PTX +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else -// This version should generate 2 shifts i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return (u << (32 - bits)) >> (32 - bits); } -// This version should generate 2 ANDs and one subtract -//i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; if (bits == 1) return -(u & 1); return ulowFixedBits(u, bits - 1) - (u & (1 << bits)); } -i32 OVERLOAD lowFixedBits(u32 u, const u32 bits) { return lowFixedBits((i32) u, bits); } #endif +i32 OVERLOAD lowFixedBits(u32 u, const u32 bits) { return lowFixedBits((i32)u, bits); } // Return signed low bits where number of bits is known at compile time (number of bits can be 1 to 63). The two versions are the same speed on TitanV. i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u << (64 - bits)) >> (64 - bits)); } //i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } -i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64) u, bits); } +i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64)u, bits); } // Extract 32 bits from a 64-bit value #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, as_int2(x).x, bits); } +#elif HAS_PTX +i32 xtract32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } #else i32 xtract32(i64 x, u32 bits) { return x >> bits; } #endif @@ -286,7 +297,7 @@ i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { #elif FFT_TYPE == FFT6431 // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. -i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, float* maxROE) { +i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { // Apply inverse weight and get the Z31 data u31 = shr(u31, m31_invWeight); @@ -311,7 +322,8 @@ i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, i64 inCarry, #if MUL3 value = add(value, add(value, value)); #endif - return add(value, inCarry); + if (hasInCarry) value = add(value, inCarry); + return value; } /**************************************************************************/ @@ -352,7 +364,7 @@ i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, #elif FFT_TYPE == FFT3261 // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. -i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { +i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { // Apply inverse weight and get the Z61 data u61 = shr(u61, m61_invWeight); @@ -378,7 +390,8 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, #if MUL3 value = add(value, add(value, value)); #endif - return add(value, inCarry); + if (hasInCarry) value = add(value, inCarry); + return value; } /**************************************************************************/ @@ -388,7 +401,7 @@ i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, #elif FFT_TYPE == FFT3161 // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. -i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, u32* maxROE) { +i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, u32* maxROE) { // Apply inverse weights u31 = shr(u31, m31_invWeight); @@ -418,7 +431,8 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 #if MUL3 value = add(value, add(value, value)); #endif - return add(value, inCarry); + if (hasInCarry) value = add(value, inCarry); + return value; } /******************************************************************************/ @@ -428,7 +442,7 @@ i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, i6 #elif FFT_TYPE == FFT323161 // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. -i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_invWeight, u32 m61_invWeight, i64 inCarry, float* maxROE) { +i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { // Apply inverse weights u31 = shr(u31, m31_invWeight); @@ -462,7 +476,7 @@ i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_ #if MUL3 v = add(v, add(v, v)); #endif - v = add(v, inCarry); + if (hasInCarry) v = add(v, inCarry); return v; } @@ -484,25 +498,27 @@ Word OVERLOAD carryStep(i128 x, i64 *outCarry, bool isBigWord) { Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); + u32 nBitsLess32 = bitlen(isBigWord) - 32; // This code can be tricky because we must not shift i32 or u32 variables by 32. #if EXP / NWORDS >= 33 - i64 xhi = i96_hi64(x); - i64 w = lowBits(xhi, nBits - 32); - *outCarry = (xhi - w) >> (nBits - 32); - return (w << 32) | i96_lo32(x); + i32 whi = lowBits(i96_mid32(x), nBitsLess32); + *outCarry = ((i64)i96_hi64(x) - (i64)whi) >> nBitsLess32; + return as_ulong((uint2)(i96_lo32(x), (u32)whi)); #elif EXP / NWORDS == 32 - i64 xhi = i96_hi64(x); - i64 w = lowBits(i96_lo64(x), nBits); - *outCarry = (xhi - (w >> 32)) >> (nBits - 32); - return w; + i32 whi = xtract32(i96_lo64(x), nBitsLess32) >> 31; + *outCarry = ((i64)i96_hi64(x) - (i64)whi) >> nBitsLess32; + return as_ulong((uint2)(i96_lo32(x), (u32)whi)); #elif EXP / NWORDS == 31 - i64 w = lowBits(i96_lo64(x), nBits); - *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); + i32 w = lowBitsSafe32(i96_lo32(x), nBits); + *outCarry = as_long((int2)(xtract32(i96_lo64(x), nBits), xtract32(i96_hi64(x), nBits))) + (w < 0); return w; +// i64 w = lowBits(i96_lo64(x), nBits); +// *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); +// return w; #else i32 w = lowBits(i96_lo32(x), nBits); - *outCarry = ((i96_hi64(x) << (32 - nBits)) | (i96_lo32(x) >> nBits)) + (w < 0); + *outCarry = as_long((int2)(xtract32(i96_lo64(x), nBits), xtract32(i96_hi64(x), nBits))) + (w < 0); return w; #endif } @@ -510,22 +526,22 @@ Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 33 - i32 xhi = (x >> 32); + i32 xhi = hi32(x); i32 whi = lowBits(xhi, nBits - 32); *outCarry = (xhi - whi) >> (nBits - 32); - return (Word) (((u64) whi << 32) | (u32)(x)); + return (Word) as_long((int2)(lo32(x), whi)); #elif EXP / NWORDS == 32 - i32 xhi = (x >> 32); + i32 xhi = hi32(x); i64 w = lowBits(x, nBits); - xhi -= w >> 32; + xhi -= (i32)(w >> 32); *outCarry = xhi >> (nBits - 32); return w; #elif EXP / NWORDS == 31 - i64 w = lowBits(x, nBits); + i32 w = lowBitsSafe32(lo32(x), nBits); *outCarry = (x - w) >> nBits; return w; #else - Word w = lowBits((i32) x, nBits); + Word w = lowBits(lo32(x), nBits); *outCarry = (x - w) >> nBits; return w; #endif @@ -534,21 +550,21 @@ Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 33 - i32 xhi = (x >> 32); + i32 xhi = hi32(x); i32 w = lowBits(xhi, nBits - 32); *outCarry = (xhi >> (nBits - 32)) + (w < 0); - return (Word) (((u64) w << 32) | (u32)(x)); + return as_long((int2)(lo32(x), w)); #elif EXP / NWORDS == 32 - i32 xhi = (x >> 32); + i32 xhi = hi32(x); i64 w = lowBits(x, nBits); - *outCarry = (i32) (xhi >> (nBits - 32)) + (w < 0); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); return w; #elif EXP / NWORDS == 31 - i32 w = lowBits(x, nBits); - *outCarry = (i32) (x >> nBits) + (w < 0); + i32 w = lowBitsSafe32(lo32(x), nBits); + *outCarry = xtract32(x, nBits) + (w < 0); return w; #else - Word w = lowBits(x, nBits); + i32 w = lowBits(x, nBits); *outCarry = xtract32(x, nBits) + (w < 0); return w; #endif @@ -584,15 +600,15 @@ Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { // Return a Word using the big word size. Big word size is a constant which allows for more optimization. #if EXP / NWORDS >= 32 // nBits is 32 or more - i64 xhi = i96_hi64(x) & ~(((u64)1 << (bigwordBits - 32)) - 1); + i64 xhi = as_ulong((uint2)(i96_mid32(x) & ~((1 << (bigwordBits - 32)) - 1), i96_hi32(x))); *outCarry = xhi >> (nBits - 32); - return ulowFixedBits(i96_lo64(x), bigwordBits); -#elif EXP / NWORDS == 31 // nBits = 31 or 32 + return as_ulong((uint2)(i96_lo32(x), ulowFixedBits(i96_mid32(x), bigwordBits - 32))); +#elif EXP / NWORDS == 31 || EXP / NWORDS >= 22 // nBits = 31 or 32, fastest version. Should also work on smaller nBits. *outCarry = i96_hi64(x) << (32 - nBits); return i96_lo32(x); // ulowBits(x, bigwordBits = 32); #else // nBits less than 32 u32 w = ulowFixedBits(i96_lo32(x), bigwordBits); - *outCarry = (i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) - w) >> nBits); + *outCarry = as_long((int2)(xtract32(as_long((int2)(i96_lo32(x) - w, i96_mid32(x))), nBits), xtract32(i96_hi64(x), nBits))); return w; #endif } @@ -624,7 +640,7 @@ Word OVERLOAD carryStepUnsignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { // We only allow sloppy results when not near the maximum bits-per-word. For now, this is defined as 1.1 bits below maxbpw. // No studies have been done on reducing this 1,1 value since this is a rather minor optimization. Since the preprocessor can't // handle floats, the MAXBPW value passed in is 100 * maxbpw. -#define SLOPPY_MAXBPW (MAXBPW - 1100) +#define SLOPPY_MAXBPW (MAXBPW - 110) #define ACTUAL_BPW (EXP / (NWORDS / 100)) Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { @@ -632,6 +648,8 @@ Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { return carryStep(x, outCarry, isBigWord); #else +//GW: Need to compare to simple carryStep + // Return a Word using the big word size. Big word size is a constant which allows for more optimization. const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); @@ -652,20 +670,21 @@ Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); #if EXP / NWORDS >= 32 // nBits is 32 or more - u64 xlo = i96_lo64(x); - u64 xlo_topbit = xlo & ((u64)1 << (bigwordBits - 1)); - i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; - i64 xhi = i96_hi64(x) + (xlo_topbit >> 32); - *outCarry = xhi >> (nBits - 32); - return w; + return carryStep(x, outCarry, isBigWord); // Should be just as fast as code below +// u32 xmid_topbit = i96_mid32(x) & (1 << (bigwordBits - 32 - 1)); +// i32 whi = ulowFixedBits(i96_mid32(x), bigwordBits - 32 - 1) - xmid_topbit; +// i64 xhi = i96_hi64(x) + xmid_topbit; +// *outCarry = xhi >> (nBits - 32); +// return as_long((int2)(i96_lo32(x), whi)); #elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 3200 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) i32 w = i96_lo32(x); // lowBits(x, bigwordBits = 32); *outCarry = (i96_hi64(x) + (w < 0)) << (32 - nBits); return w; -#else // nBits less than 32 //GWBUG - is there a faster version? Is this faster than plain old carryStep? - i32 w = lowFixedBits(i96_lo32(x), bigwordBits); - *outCarry = (((i96_hi64(x) << (32 - bigwordBits)) | (i96_lo32(x) >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); - return w; +#else // nBits less than 32 + return carryStep(x, outCarry, isBigWord); // Should be faster than code below +// i32 w = lowFixedBits(i96_lo32(x), bigwordBits); +// *outCarry = (as_long((int2)(xtract32(i96_lo64(x), bigwordBits), xtract32(i96_hi64(x), bigwordBits))) + (w < 0)) << (bigwordBits - nBits); +// return w; #endif #endif } @@ -675,12 +694,7 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { return carryStep(x, outCarry, isBigWord); #else -// GWBUG - not timed to see if it is faster. Highly likely to be slower. -// const u32 bigwordBits = EXP / NWORDS + 1; -// u32 nBits = bitlen(isBigWord); -// u32 w = lowBits(x, bigwordBits); -// *outCarry = (((x << (32 - bigwordBits)) | ((u32) x >> bigwordBits)) + (w < 0)) << (bigwordBits - nBits); -// return w; + // We're unlikely to find code that is better than carryStep return carryStep(x, outCarry, isBigWord); #endif } @@ -690,6 +704,8 @@ Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { return carryStep(x, outCarry, isBigWord); #else +//GW: I need to look at PTX code generated by the code below vs. carryStep + // Return a Word using the big word size. Big word size is a constant which allows for more optimization. const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); diff --git a/src/cl/math.cl b/src/cl/math.cl index 75dcbc0d..aabc33a6 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -6,26 +6,54 @@ // Access parts of a 64-bit value -u32 lo32(u64 x) { return (u32) x; } -u32 hi32(u64 x) { return (u32) (x >> 32); } +u32 OVERLOAD lo32(u64 x) { uint2 x2 = as_uint2(x); return (u32)x2.x; } +u32 OVERLOAD hi32(u64 x) { uint2 x2 = as_uint2(x); return (u32)x2.y; } +u32 OVERLOAD lo32(i64 x) { uint2 x2 = as_uint2(x); return (u32)x2.x; } +i32 OVERLOAD hi32(i64 x) { uint2 x2 = as_uint2(x); return (i32)x2.y; } // A primitive partial implementation of an i96 integer type -#if 0 -// An all u32 implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. +#if 1 +// An all 32-bit implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. // This version might be best on AMD and Intel if we can generate add-with-carry instructions. -typedef struct { u32 lo32; u32 mid32; u32 hi32; } i96; -i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.mid32 = v >> 32, val.lo32 = v; return val; } -i96 OVERLOAD make_i96(i32 v) { i96 val; val.hi32 = v >> 31, val.mid32 = v >> 31, val.lo32 = v; return val; } -i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.mid32 = lo >> 32, val.lo32 = lo; return val; } -i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi, val.mid32 = lo >> 32, val.lo32 = lo; return val; } +typedef struct { i32 hi32; u32 mid32; u32 lo32; } i96; +i96 OVERLOAD make_i96(i64 v) { i96 val; val.lo32 = lo32(v); val.mid32 = hi32(v); val.hi32 = (i32)val.mid32 >> 31; return val; } +i96 OVERLOAD make_i96(i32 v) { i96 val; val.lo32 = v; val.mid32 = val.hi32 = (i32)val.lo32 >> 31; return val; } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi; val.mid32 = hi32(lo); val.lo32 = lo32(lo); return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi; val.mid32 = hi32(lo); val.lo32 = lo32(lo); return val; } u32 i96_hi32(i96 val) { return val.hi32; } u32 i96_mid32(i96 val) { return val.mid32; } u32 i96_lo32(i96 val) { return val.lo32; } -u64 i96_lo64(i96 val) { return ((u64) val.mid32 << 32) | val.lo32; } -u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | val.mid32; } -i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 + b.lo32; val.mid32 = a.mid32 + b.mid32; val.hi32 = a.hi32 + b.hi32 + (val.mid32 < a.mid32); u32 carry = (val.lo32 < a.lo32); u32 tmp = val.mid32; val.mid32 += carry; val.hi32 += (val.mid32 < tmp); return val; } +u64 i96_lo64(i96 val) { return as_ulong((uint2)(val.lo32, val.mid32)); } +u64 i96_hi64(i96 val) { return as_ulong((uint2)(val.mid32, val.hi32)); } +i96 OVERLOAD add(i96 a, i96 b) { + i96 val; +#if HAS_PTX + __asm("add.cc.u32 %0, %3, %6;\n\t" + "addc.cc.u32 %1, %4, %7;\n\t" + "addc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); +#else + u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 + blo64; + val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 + b.hi32 + (lo64 < alo64); +#endif + return val; +} i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } -i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo32 = a.lo32 - b.lo32; val.mid32 = a.mid32 - b.mid32; val.hi32 = a.hi32 - b.hi32 - (val.mid32 > a.mid32); u32 carry = (val.lo32 > a.lo32); u32 tmp = val.mid32; val.mid32 -= carry; val.hi32 -= (val.mid32 > tmp); return val; } +i96 OVERLOAD sub(i96 a, i96 b) { + i96 val; +#if HAS_PTX + __asm("sub.cc.u32 %0, %3, %6;\n\t" + "subc.cc.u32 %1, %4, %7;\n\t" + "subc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); +#else + u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 - blo64; + val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 - b.hi32 - (lo64 > alo64); +#endif + return val; +} i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } #elif defined(__SIZEOF_INT128__) From ec16578dfb8e3f707e3fa0aa8a9facd69bb2b404 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 25 Oct 2025 04:15:26 +0000 Subject: [PATCH 080/115] Ignore first call to setSquareTime. First timings are inaccurate due to substantial startup costs. --- src/Queue.cpp | 9 +++++++-- src/Queue.h | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Queue.cpp b/src/Queue.cpp index 8ef83adc..1dd8f011 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -25,11 +25,12 @@ Queue::Queue(const Context& context, bool profile) : markerQueued(false), queueCount(0), squareTime(50), - squareKernels(4) + squareKernels(4), + firstSetTime(true) { // Formerly a constant (thus the CAPS). nVidia is 3% CPU load at 400 or 500, and 35% load at 800 on my Linux machine. // AMD is just over 2% load at 1600 and 3200 on the same Linux machine. Marginally better timings(?) at 3200. - MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings (if squareKernels = 4) + MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings (if squareKernels = 4) } void Queue::writeTE(cl_mem buf, u64 size, const void* data, TimeInfo* tInfo) { @@ -110,6 +111,10 @@ void Queue::waitForMarkerEvent() { } void Queue::setSquareTime(int time) { + if (firstSetTime) { // Ignore first setSquareTime call. First measured times are wrong because of startup costs + firstSetTime = false; + return; + } if (time < 30) time = 30; // Assume a minimum square time of 30us if (time > 3000) time = 3000; // Assume a maximum square time of 3000us squareTime = time; diff --git a/src/Queue.h b/src/Queue.h index 28aea967..2ea110cc 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -50,7 +50,7 @@ class Queue : public QueueHolder { void copyBuf(cl_mem src, cl_mem dst, u32 size, TimeInfo* tInfo); void finish(); - void setSquareTime(int); // Set the time to do one squaring (in microseconds) + void setSquareTime(int); // Update the time to do one squaring (in microseconds) void setSquareKernels(int n) { squareKernels = n; } private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. @@ -60,6 +60,7 @@ class Queue : public QueueHolder { int queueCount; // Count of items added to the queue since last marker int squareTime; // Time to do one squaring (in microseconds) int squareKernels; // Number of kernels in one squaring + bool firstSetTime; // Flag so we can ignore first setSquareTime call (which is inaccurate because of all the initial openCL compiles) void queueMarkerEvent(); // Queue the marker event void waitForMarkerEvent(); // Wait for marker event to complete }; From d83506d398e2be2c623d976ca21810b370428fb6 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 26 Oct 2025 02:02:36 +0000 Subject: [PATCH 081/115] Change maxBpw calculation as only some FFT types support both 32 and 64 bit carries. --- src/FFTConfig.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index f60d5d30..85fd02bb 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -266,7 +266,8 @@ float FFTConfig::maxBpw() const { float b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; b = (b1 + b2) / 2.0; } - return carry == CARRY_32 ? std::min(shape.carry32BPW(), b) : b; + // Only some FFTs support both 32 and 64 bit carries. + return (carry == CARRY_32 && (shape.fft_type == FFT64 || shape.fft_type == FFT3231)) ? std::min(shape.carry32BPW(), b) : b; } FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { From b5690838bdb38c83224b3c616a1585d1b5a008ad Mon Sep 17 00:00:00 2001 From: george Date: Sun, 26 Oct 2025 08:42:19 +0000 Subject: [PATCH 082/115] Allow leadIn/leadOut to be used across a modMul call. A very minor optimization since this only happens once every blockSize squarings. --- src/Gpu.cpp | 59 +++++++++++++++++++++++++++-------------------------- src/Gpu.h | 4 +--- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 2d5518fd..6503ef7b 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -759,13 +759,12 @@ void Gpu::fftHin(Buffer& out, Buffer& in) { if (fft.NTT_GF61) kfftHinGF61(out, in); } -void Gpu::tailSquareZero(Buffer& out, Buffer& in) { - if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquareZero(out, in); - if (fft.NTT_GF31) ktailSquareZeroGF31(out, in); - if (fft.NTT_GF61) ktailSquareZeroGF61(out, in); -} - void Gpu::tailSquare(Buffer& out, Buffer& in) { + if (!tail_single_kernel) { + if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquareZero(out, in); + if (fft.NTT_GF31) ktailSquareZeroGF31(out, in); + if (fft.NTT_GF61) ktailSquareZeroGF61(out, in); + } if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquare(out, in); if (fft.NTT_GF31) ktailSquareGF31(out, in); if (fft.NTT_GF61) ktailSquareGF61(out, in); @@ -937,7 +936,7 @@ Words Gpu::readAndCompress(Buffer& buf) { return compactBits(readChecked( vector Gpu::readCheck() { return readAndCompress(bufCheck); } vector Gpu::readData() { return readAndCompress(bufData); } -// out := inA * inB; +// out := inA * inB; inB is preserved void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { fftP(tmp1, ioA); fftMidIn(tmp2, tmp1); @@ -959,8 +958,15 @@ void Gpu::mul(Buffer& io, Buffer& buf1) { // out := inA * inB; void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { - fftP(buf2, inB); - fftMidIn(buf1, buf2); + modMul(ioA, true, inB, mul3); +}; + +// out := inA * inB; if leadInB set then inB (a.k.a. buf1) is preserved +void Gpu::modMul(Buffer& ioA, bool leadInB, Buffer& inB, bool mul3) { + if (leadInB) { + fftP(buf2, inB); + fftMidIn(buf1, buf2); + } mul(ioA, buf1, buf2, buf3, mul3); }; @@ -1108,13 +1114,6 @@ Words Gpu::expMul(const Words& A, u64 h, const Words& B, bool doSquareB) { static bool testBit(u64 x, int bit) { return x & (u64(1) << bit); } -void Gpu::bottomHalf(Buffer& out, Buffer& inTmp) { - fftMidIn(out, inTmp); - if (!tail_single_kernel) tailSquareZero(inTmp, out); - tailSquare(inTmp, out); - fftMidOut(out, inTmp); -} - // See "left-to-right binary exponentiation" on wikipedia void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3) { if (exp == 0) { @@ -1128,7 +1127,9 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu while (!testBit(exp, p)) { --p; } for (--p; ; --p) { - bottomHalf(buf2, buf3); + fftMidIn(buf2, buf3); + tailSquare(buf3, buf2); + fftMidOut(buf2, buf3); if (testBit(exp, p)) { doCarry(buf3, buf2); @@ -1164,9 +1165,13 @@ void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, // LL does not do Mul3 assert(!(doMul3 && doLL)); - if (leadIn) { fftP(buf2, in); } + if (leadIn) { + fftP(buf2, in); + fftMidIn(buf1, buf2); + } - bottomHalf(buf1, buf2); + tailSquare(buf2, buf1); + fftMidOut(buf1, buf2); if (leadOut) { fftW(buf2, buf1); @@ -1188,6 +1193,7 @@ void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, carryFused(buf2, buf1); } // Unused: carryFusedMul(buf2, buf1); + fftMidIn(buf1, buf2); } } @@ -1738,16 +1744,16 @@ PRPResult Gpu::isPrimePRP(const Task& task) { if (skipNextCheckUpdate) { skipNextCheckUpdate = false; } else if (k % blockSize == 0) { - assert(leadIn); - modMul(bufCheck, bufData); + modMul(bufCheck, leadIn, bufData); } ++k; // !! early inc bool doStop = (k % blockSize == 0) && (Signal::stopRequested() || (args.iters && k - startK >= args.iters)); - bool leadOut = (k % blockSize == 0) || k == persistK || k == kEnd || useLongCarry; + bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); + bool doLog = k % logStep == 0; + bool leadOut = doCheck || doLog || k == persistK || k == kEnd || useLongCarry; - assert(!doStop || leadOut); if (doStop) { log("Stopping, please wait..\n"); } square(bufData, bufData, leadIn, leadOut, false); @@ -1775,12 +1781,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { log("%s %8d / %d, %s\n", isPrime ? "PP" : "CC", kEnd, E, hex(finalRes64).c_str()); } - bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); - bool doLog = k % logStep == 0; - - if (!leadOut || (!doCheck && !doLog)) continue; - - assert(doCheck || doLog); + if (!doCheck && !doLog) continue; u64 res = dataResidue(); float secsPerIt = iterationTimer.reset(k); diff --git a/src/Gpu.h b/src/Gpu.h index 230e8059..3be64aa4 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -222,7 +222,6 @@ class Gpu { void fftMidIn(Buffer& out, Buffer& in); void fftMidOut(Buffer& out, Buffer& in); void fftHin(Buffer& out, Buffer& in); - void tailSquareZero(Buffer& out, Buffer& in); void tailSquare(Buffer& out, Buffer& in); void tailMul(Buffer& out, Buffer& in1, Buffer& in2); void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2); @@ -257,8 +256,6 @@ class Gpu { void exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3); - void bottomHalf(Buffer& out, Buffer& inTmp); - void writeState(u32 k, const vector& check, u32 blockSize); // does either carrryFused() or the expanded version depending on useLongCarry @@ -268,6 +265,7 @@ class Gpu { void mul(Buffer& io, Buffer& inB); void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); + void modMul(Buffer& ioA, bool leadInB, Buffer& inB, bool mul3 = false); fs::path saveProof(const Args& args, const ProofSet& proofSet); std::pair readROE(); From fedb6e4bc95935757a785a2601aea0eab22ee10a Mon Sep 17 00:00:00 2001 From: george Date: Sun, 26 Oct 2025 09:07:32 +0000 Subject: [PATCH 083/115] Further fixes to ignoring first setSquareTime. Previous fix only worked on startup, not on each exponent processed from worktodo.txt. --- src/Queue.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Queue.h b/src/Queue.h index 2ea110cc..4ed4200a 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -51,7 +51,7 @@ class Queue : public QueueHolder { void finish(); void setSquareTime(int); // Update the time to do one squaring (in microseconds) - void setSquareKernels(int n) { squareKernels = n; } + void setSquareKernels(int n) { squareKernels = n; firstSetTime = true; } private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. int MAX_QUEUE_COUNT; // Queue size before a marker will be enqueued. Typically, 100 to 1000 squarings. From aa8af680044be959b01a4b3ab89128c21c500b8e Mon Sep 17 00:00:00 2001 From: george Date: Sun, 26 Oct 2025 17:09:26 +0000 Subject: [PATCH 084/115] Corrected comments on max values during a GF61 cmul --- src/cl/math.cl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index aabc33a6..c546be0b 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -898,20 +898,20 @@ GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 3-epsilon extra bits in return U2(modM61(k1 + neg(k3, 4)), modM61(k1 + k2)); } #else -Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b_m61_count) { - u128 ab = mad64(a, b, c); // Max c value assumed to be M61^2+epsilon +Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b_m61_count) { // Max c value assumed to be 2*M61^2+epsilon + u128 ab = mad64(a, b, c); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61^2 + epsilon u64 lo = u128_lo64(ab), hi = u128_hi64(ab); u64 lo61 = lo & M61; // Max value is M61 - if ((a_m61_count - 1) * (b_m61_count - 1) + 1 <= 6) { - hi = (hi << 3) + (lo >> 61); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 1) * M61 + epsilon - return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61 + epsilon + if ((a_m61_count - 1) * (b_m61_count - 1) + 2 <= 6) { + hi = (hi << 3) + (lo >> 61); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61 + epsilon + return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 3) * M61 + epsilon } else { u64 hi61 = ((hi << 3) + (lo >> 61)) & M61; // Max value is M61 return lo61 + hi61 + (hi >> 58); // Max value is 2*M61 + epsilon } } GF61 OVERLOAD cmul(GF61 a, GF61 b) { - u128 k1 = mul64(b.x, a.x + a.y); // max value is M61^2+epsilon + u128 k1 = mul64(b.x, a.x + a.y); // max value is 2*M61^2+epsilon Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 3); // max value is 4*M61+epsilon Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 5*M61+epsilon return U2(modM61(k1k3), modM61(k1k2)); From c0062476dd0f49b87262d8ed374627a770253545 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 26 Oct 2025 17:32:43 +0000 Subject: [PATCH 085/115] More comments corrections --- src/cl/math.cl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index c546be0b..019704ea 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -30,9 +30,9 @@ i96 OVERLOAD add(i96 a, i96 b) { #if HAS_PTX __asm("add.cc.u32 %0, %3, %6;\n\t" "addc.cc.u32 %1, %4, %7;\n\t" - "addc.u32 %2, %5, %8;" - : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) - : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); + "addc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); #else u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 + blo64; val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 + b.hi32 + (lo64 < alo64); @@ -45,9 +45,9 @@ i96 OVERLOAD sub(i96 a, i96 b) { #if HAS_PTX __asm("sub.cc.u32 %0, %3, %6;\n\t" "subc.cc.u32 %1, %4, %7;\n\t" - "subc.u32 %2, %5, %8;" - : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) - : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); + "subc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); #else u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 - blo64; val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 - b.hi32 - (lo64 > alo64); @@ -912,8 +912,8 @@ Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b } GF61 OVERLOAD cmul(GF61 a, GF61 b) { u128 k1 = mul64(b.x, a.x + a.y); // max value is 2*M61^2+epsilon - Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 3); // max value is 4*M61+epsilon - Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 5*M61+epsilon + Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 4); // max value is 6*M61+epsilon + Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 6*M61+epsilon return U2(modM61(k1k3), modM61(k1k2)); } #endif From 6fc7e933715ab0f0ac075b65325675e2be934c68 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 27 Oct 2025 02:41:23 +0000 Subject: [PATCH 086/115] Use user specified quick value and exponent to adjust number of iterations to execute in building tune.txt --- src/tune.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index 15fc2bc0..81f3be28 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -939,9 +939,10 @@ skip_1K_256 = 0; // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); -//GW: If user specified a quick != 7, adjust the formula below??? - quick = (exponent < 50000000) ? 6 : (exponent < 150000000) ? 7 : (exponent < 350000000) ? 8 : 10; - + u32 adjusted_quick = (exponent < 50000000) ? quick - 1 : (exponent < 150000000) ? quick : (exponent < 350000000) ? quick + 1 : quick + 2; + if (adjusted_quick < 1) adjusted_quick = 1; + if (adjusted_quick > 10) adjusted_quick = 10; + // Loop through all possible variants for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { @@ -990,7 +991,7 @@ skip_1K_256 = 0; if (w == 0 && !AMDGPU) continue; if (w == 0 && test.width > 1024) continue; FFTConfig fft{test, variant_WMH (w, 0, 1), CARRY_32}; - cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(quick); + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(adjusted_quick); log("Fast width search %6.1f %12s\n", cost, fft.spec().c_str()); if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_width = w; } } From 81b3f0c659df19ad71a8fab18b4ddfddac1619b2 Mon Sep 17 00:00:00 2001 From: george Date: Mon, 27 Oct 2025 17:44:16 +0000 Subject: [PATCH 087/115] Minor tweak so that M31*M61 NTTs time a few more iterations for wavefront exponents. --- src/tune.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index 81f3be28..b0b73b58 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -939,7 +939,7 @@ skip_1K_256 = 0; // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); - u32 adjusted_quick = (exponent < 50000000) ? quick - 1 : (exponent < 150000000) ? quick : (exponent < 350000000) ? quick + 1 : quick + 2; + u32 adjusted_quick = (exponent < 50000000) ? quick - 1 : (exponent < 170000000) ? quick : (exponent < 350000000) ? quick + 1 : quick + 2; if (adjusted_quick < 1) adjusted_quick = 1; if (adjusted_quick > 10) adjusted_quick = 10; From 1a591dacf5d23bedc788c95a2b38eed0f7d779b2 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 30 Oct 2025 21:57:59 +0000 Subject: [PATCH 088/115] Changed the wording in some -tune messages --- src/tune.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index b0b73b58..8168e591 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -413,18 +413,18 @@ void Tune::tune() { defaultShape = &defaultFFTShape; time_FFTs = 1; if (fp64_time < 0.80 * ntt_time) { - log("FP64 FFTs are significantly faster than NTTs. No NTT tuning will be performed.\n"); + log("FP64 FFTs are significantly faster than integer NTTs. No NTT tuning will be performed.\n"); } else { - log("FP64 FFTs are not significantly faster than NTTs. NTT tuning will be performed.\n"); + log("FP64 FFTs are not significantly faster than integer NTTs. NTT tuning will be performed.\n"); time_NTTs = 1; } } else { defaultShape = &defaultNTTShape; time_NTTs = 1; if (fp64_time > 1.20 * ntt_time) { - log("FP64 FFTs are significantly slower than NTTs. No FP64 tuning will be performed.\n"); + log("FP64 FFTs are significantly slower than integer NTTs. No FP64 tuning will be performed.\n"); } else { - log("FP64 FFTs are not significantly slower than NTTs. FP64 tuning will be performed.\n"); + log("FP64 FFTs are not significantly slower than integer NTTs. FP64 tuning will be performed.\n"); time_FFTs = 1; } } @@ -900,8 +900,8 @@ void Tune::tune() { config.write("\n -log 1000000\n"); } if (args->workers < 2) { - config.write("\n# Running two workers often gives better throughput."); - config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers will often be better."); + config.write("\n# Running two workers sometimes gives better throughput."); + config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers may be better."); config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); } } From 365f492448c49a6739aeccd32538049dfe899118 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 1 Nov 2025 21:10:30 +0000 Subject: [PATCH 089/115] Reduce default exponent for config tuning FP32. MaxExp depends on variant,TAIL_TRIGS32,TABMUL_CHAIN32 settings. --- src/tune.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index 8168e591..4993eb1b 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -615,7 +615,7 @@ void Tune::tune() { if (time_NTTs) { FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); - u32 exponent = primes.prevPrime(fft.maxExp()); + u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw u32 best_tail_trigs = 0; u32 current_tail_trigs = args->value("TAIL_TRIGS32", 2); double best_cost = -1.0; @@ -694,11 +694,11 @@ void Tune::tune() { args->flags["TABMUL_CHAIN31"] = to_string(best_tabmul_chain); } - // Find best TABMUL_CHAIN61 setting + // Find best TABMUL_CHAIN32 setting if (time_NTTs) { FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); - u32 exponent = primes.prevPrime(fft.maxExp()); + u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw u32 best_tabmul_chain = 0; u32 current_tabmul_chain = args->value("TABMUL_CHAIN32", 0); double best_cost = -1.0; From df877846c72feb6e39ace2a663db68c106a895ad Mon Sep 17 00:00:00 2001 From: george Date: Sun, 2 Nov 2025 01:24:47 +0000 Subject: [PATCH 090/115] Fixed AMD asm problem handling BPW between 31 and 32 --- src/cl/carryutil.cl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 9f784a35..9a78290b 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -55,9 +55,7 @@ i64 OVERLOAD lowBits(i64 u, u32 bits) { return ((u << (64 - bits)) >> (64 - bits i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64)u, bits); } // Return signed low bits (number of bits must be between 1 and 32) -#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) -i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } -#elif HAS_PTX +#if HAS_PTX i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } #else i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits((u64)u, bits); } @@ -78,7 +76,7 @@ i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFix //i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64)u, bits); } -// Extract 32 bits from a 64-bit value +// Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, as_int2(x).x, bits); } #elif HAS_PTX @@ -87,6 +85,13 @@ i32 xtract32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" i32 xtract32(i64 x, u32 bits) { return x >> bits; } #endif +// Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 32) +#if HAS_PTX +i32 xtractSafe32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } +#else +i32 xtractSafe32(i64 x, u32 bits) { return x >> bits; } +#endif + u32 bitlen(bool b) { return EXP / NWORDS + b; } bool test(u32 bits, u32 pos) { return (bits >> pos) & 1; } @@ -511,7 +516,7 @@ Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { return as_ulong((uint2)(i96_lo32(x), (u32)whi)); #elif EXP / NWORDS == 31 i32 w = lowBitsSafe32(i96_lo32(x), nBits); - *outCarry = as_long((int2)(xtract32(i96_lo64(x), nBits), xtract32(i96_hi64(x), nBits))) + (w < 0); + *outCarry = as_long((int2)(xtractSafe32(i96_lo64(x), nBits), xtractSafe32(i96_hi64(x), nBits))) + (w < 0); return w; // i64 w = lowBits(i96_lo64(x), nBits); // *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); @@ -561,7 +566,7 @@ Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { return w; #elif EXP / NWORDS == 31 i32 w = lowBitsSafe32(lo32(x), nBits); - *outCarry = xtract32(x, nBits) + (w < 0); + *outCarry = xtractSafe32(x, nBits) + (w < 0); return w; #else i32 w = lowBits(x, nBits); From 48c9208cd47b437cca9aafe1aa335781f14da3a6 Mon Sep 17 00:00:00 2001 From: george Date: Thu, 6 Nov 2025 21:54:22 +0000 Subject: [PATCH 091/115] Changed from executing smallest exponent in worktodo.txt to requiring command line argument -smallest --- src/Args.cpp | 2 ++ src/Args.h | 1 + src/Worktodo.cpp | 8 ++++---- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 28814d06..4448d8ec 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -156,6 +156,7 @@ named "config.txt" in the prpll run directory. -prp : run a single PRP test and exit, ignoring worktodo.txt -ll : run a single LL test and exit, ignoring worktodo.txt -verify : verify PRP-proof contained in +-smallest : work on smallest exponent in worktodo.txt rather than the first exponent in worktodo.txt -proof : generate proof of power (default: optimal depending on exponent). A lower power reduces disk space requirements but increases the verification cost. A higher power increases disk usage a lot. @@ -367,6 +368,7 @@ void Args::parse(const string& line) { else if (key == "-iters") { iters = stoi(s); assert(iters && (iters % 10000 == 0)); } else if (key == "-prp" || key == "-PRP") { prpExp = stoll(s); } else if (key == "-ll" || key == "-LL") { llExp = stoll(s); } + else if (key == "-smallest") { smallest = true; } else if (key == "-fft") { fftSpec = s; } else if (key == "-dump") { dump = s; } else if (key == "-user") { user = s; } diff --git a/src/Args.h b/src/Args.h index 64823d7b..795cd99c 100644 --- a/src/Args.h +++ b/src/Args.h @@ -62,6 +62,7 @@ class Args { bool verbose = false; bool useCache = false; bool profile = false; + bool smallest = false; fs::path masterDir; fs::path proofResultDir = "proof"; diff --git a/src/Worktodo.cpp b/src/Worktodo.cpp index bc3e4d82..0a981a39 100644 --- a/src/Worktodo.cpp +++ b/src/Worktodo.cpp @@ -93,13 +93,13 @@ std::optional parse(const std::string& line) { } // Among the valid tasks from fileName, return the "best" which means the smallest CERT, or otherwise the exponent PRP/LL -static std::optional bestTask(const fs::path& fileName) { +static std::optional bestTask(const fs::path& fileName, bool smallest) { optional best; for (const string& line : File::openRead(fileName)) { optional task = parse(line); if (task && (!best || (best->kind != Task::CERT && task->kind == Task::CERT) - || ((best->kind != Task::CERT || task->kind == Task::CERT) && task->exponent < best->exponent))) { + || ((best->kind != Task::CERT || task->kind == Task::CERT) && smallest && task->exponent < best->exponent))) { best = task; } } @@ -112,7 +112,7 @@ optional getWork(Args& args, i32 instance) { fs::path localWork = workName(instance); // Try to get a task from the local worktodo- file. - if (optional task = bestTask(localWork)) { return task; } + if (optional task = bestTask(localWork, args.smallest)) { return task; } if (args.masterDir.empty()) { return {}; } @@ -140,7 +140,7 @@ optional getWork(Args& args, i32 instance) { u64 initialSize = fileSize(worktodo); if (!initialSize) { return {}; } - optional task = bestTask(worktodo); + optional task = bestTask(worktodo, args.smallest); if (!task) { return {}; } string workLine = task->line; From 7c3571e13d0ce5d2807daaf87a71d4e02d6afece Mon Sep 17 00:00:00 2001 From: george Date: Fri, 7 Nov 2025 03:48:03 +0000 Subject: [PATCH 092/115] Improved GF31 reduction mod M31 --- src/cl/carry.cl | 18 ++++++------ src/cl/carryfused.cl | 44 +++++++++++++---------------- src/cl/fftp.cl | 16 +++++------ src/cl/math.cl | 67 ++++++++++++++++++++++++++++++++++++++++---- src/cl/weight.cl | 27 ++++++++++++++++++ 5 files changed, 124 insertions(+), 48 deletions(-) diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 8741c334..28863dd8 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -533,11 +533,11 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u // Generate the second weight shifts u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flags @@ -550,10 +550,9 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(u // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; -// GWBUG - derive m61 weight shifts from m31 counter (or vice versa) sort of easily done from difference in the two weight shifts (no need to add frac_bits twice) + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } carryOut[G_W * g + me] = carry; @@ -630,11 +629,11 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flags @@ -647,10 +646,9 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, Big // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; -// GWBUG - derive m61 weight shifts from m31 counter (or vice versa) sort of easily done from difference in the two weight shifts (no need to add frac_bits twice) + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } carryOut[G_W * g + me] = carry; diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index dba02e76..05e4ca4c 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -1694,10 +1694,8 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; - m31_weight_shift = m31_weight_shift + log2_NWORDS + 1; - if (m31_weight_shift > 31) m31_weight_shift -= 31; - m61_weight_shift = m61_weight_shift + log2_NWORDS + 1; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift + log2_NWORDS + 1); + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift + log2_NWORDS + 1); // Apply the inverse weights and carry propagate pairs to generate the output carries @@ -1705,11 +1703,11 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate the second weight shifts u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flags @@ -1726,9 +1724,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation m61_combo_counter = m61_starting_combo_counter; @@ -1819,11 +1817,11 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate the second weight shifts u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flag, propagate final carry bool biglit0 = frac_bits <= FRAC_BPW_HI; @@ -1833,9 +1831,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } bar(); @@ -1969,10 +1967,8 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; - m31_weight_shift = m31_weight_shift + log2_NWORDS + 1; - if (m31_weight_shift > 31) m31_weight_shift -= 31; - m61_weight_shift = m61_weight_shift + log2_NWORDS + 1; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift + log2_NWORDS + 1); + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift + log2_NWORDS + 1); // Apply the inverse weights and carry propagate pairs to generate the output carries @@ -1983,11 +1979,11 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flags @@ -2004,9 +2000,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation m61_combo_counter = m61_starting_combo_counter; @@ -2104,11 +2100,11 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate the second weight shifts u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Generate big-word/little-word flag, propagate final carry bool biglit0 = frac_bits <= FRAC_BPW_HI; @@ -2119,9 +2115,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } bar(); diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index d430f984..2fe0d7d4 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -442,11 +442,11 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { // Generate the second weight shifts u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Convert and weight input u31[i] = U2(shl(make_Z31(in[p].x), m31_weight_shift0), shl(make_Z31(in[p].y), m31_weight_shift1)); // Form a GF31 from each pair of input words @@ -454,9 +454,9 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } fft_WIDTH(lds31, u31, smallTrig31); @@ -533,11 +533,11 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); u32 m31_weight_shift0 = m31_weight_shift; m31_combo_counter += m31_combo_step; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); u32 m31_weight_shift1 = m31_weight_shift; u32 m61_weight_shift0 = m61_weight_shift; m61_combo_counter += m61_combo_step; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); u32 m61_weight_shift1 = m61_weight_shift; // Convert and weight input uF2[i] = U2(in[p].x * w1, in[p].y * w2); @@ -546,9 +546,9 @@ KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIG // Generate weight shifts and frac_bits for next pair m31_combo_counter += m31_combo_bigstep; - if (m31_weight_shift > 31) m31_weight_shift -= 31; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); m61_combo_counter += m61_combo_bigstep; - if (m61_weight_shift > 61) m61_weight_shift -= 61; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); } fft_WIDTH(ldsF2, uF2, smallTrigF2); diff --git a/src/cl/math.cl b/src/cl/math.cl index 019704ea..8e6304b1 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -132,6 +132,56 @@ u128 mul64(u64 a, u64 b) { u128 val; val.lo64 = a * b; val.hi64 = mul_hi(a, b); u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.lo64 = a.lo64 + b.lo64; val.hi64 = a.hi64 + b.hi64 + (val.lo64 < a.lo64); return val; } #endif +// Select based on sign of first argument. This generates less PTX code, but is no faster on 5xxx GPUs +i32 select32(i32 a, i32 b, i32 c) { +#if HAS_PTX + i32 res; + __asm("slct.s32.s32 %0, %2, %3, %1;" : "=r"(res) : "r"(a), "r"(b), "r"(c)); + return res; +#else + return a >= 0 ? b : c; +#endif +} + +// Optionally add a value if first arg is negative. +i32 optional_add(i32 a, const i32 b) { +#if HAS_PTX + __asm("{.reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 + " @%%p add.s32 %0, %0, %1;}" // if (a < 0) a = a + b + : "+r"(a) : "n"(b)); +#else + if (a < 0) a = a + b; +#endif + return a; +} + +// Optionally subtract a value if first arg is negative. +i32 optional_sub(i32 a, const i32 b) { +#if HAS_PTX + __asm("{.reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 + " @%%p sub.s32 %0, %0, %1;}" // if (a < 0) a = a - b + : "+r"(a) : "n"(b)); +#else + if (a < 0) a = a - b; +#endif + return a; +} + +// Optionally subtract a value if first arg is greater than value. +i32 optional_mod(i32 a, const i32 b) { +#if 0 //HAS_PTX // Not faster on 5xxx GPUs (not sure why) + __asm("{.reg .pred %%p;\n\t" + " setp.ge.s32 %%p, %0, %1;\n\t" // a > b + " @%%p sub.s32 %0, %0, %1;}" // if (a > b) a = a - b + : "+r"(a) : "n"(b)); +#else + if (a >= b) a = a - b; +#endif + return a; +} + // Multiply and add primitives u64 mad32(u32 a, u32 b, u64 c) { @@ -509,12 +559,18 @@ GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } -#elif 1 // This version is a little sloppy. Returns values in 0..M31 range //GWBUG (could this handle M31+1 too> neg() is hard. If so made_Z31(i64) is faster +#elif 1 // This version is a little sloppy. Returns values in 0..M31 range. // Internal routines to return value in 0..M31 range -Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) -Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) -Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF00000000 +//Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(Z31 a) { i32 alt = a + 0x80000001; return select32(a, a, alt); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +//Z31 OVERLOAD modM31(Z31 a) { return optional_add(a, 0x80000001); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) + +//Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +Z31 OVERLOAD modM31(i32 a) { i32 alt = a - 0x80000001; return select32(a, a, alt); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +//Z31 OVERLOAD modM31(i32 a) { return optional_sub(a, 0x80000001); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) + +Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF7FFFFFFF u32 alo = a & M31; u32 amid = (a >> 31) & M31; u32 ahi = a >> 62; @@ -564,8 +620,7 @@ GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } Z31 OVERLOAD shl(Z31 a, u32 k) { return shr(a, 31 - k); } GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } -//Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } //GWBUG. is M31 * M31 a problem???? I think so! needs double mod -Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return modM31(add((Z31) (t & M31), (Z31) (t >> 31))); } //Fixes the M31 * M31 problem +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return modM31(add((Z31)(t & M31), (Z31)(t >> 31))); } Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? diff --git a/src/cl/weight.cl b/src/cl/weight.cl index 62e65d96..66075bb7 100644 --- a/src/cl/weight.cl +++ b/src/cl/weight.cl @@ -133,3 +133,30 @@ F optionalHalve(F w) { // return w >= 4 ? w / 2 : w; } #endif + + +/**************************************************************************/ +/* Helper routines for NTT weight calculations */ +/**************************************************************************/ + +#if NTT_GF31 + +// if (weight_shift > 31) weight_shift -= 31; +// This version uses PTX instructions which may be faster on nVidia GPUs +u32 adjust_m31_weight_shift (u32 weight_shift) { + return optional_mod(weight_shift, 31); +} + +#endif + + +#if NTT_GF61 + +// if (weight_shift > 61) weight_shift -= 61; +// This version uses PTX instructions which may be faster on nVidia GPUs +u32 adjust_m61_weight_shift (u32 weight_shift) { + return optional_mod(weight_shift, 61); +} + +#endif + From 6977bba64aacbd9b1c1d1c29ad972e60cc3d0bfd Mon Sep 17 00:00:00 2001 From: george Date: Fri, 7 Nov 2025 18:54:50 +0000 Subject: [PATCH 093/115] More mad32 --- src/cl/math.cl | 86 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 25 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 8e6304b1..3d6df82c 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -184,18 +184,58 @@ i32 optional_mod(i32 a, const i32 b) { // Multiply and add primitives -u64 mad32(u32 a, u32 b, u64 c) { +u64 OVERLOAD mad32(u32 a, u32 b, u32 c) { #if HAS_PTX // Same speed on TitanV, any gain may be too small to measure u32 reslo, reshi; __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" - "madc.hi.u32 %1, %2, %3, %5;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"((u32)c), "r"((u32)(c >> 32))); - return ((u64)reshi << 32) | reslo; + "madc.hi.u32 %1, %2, %3, 0;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(c)); + return as_ulong((uint2)(reslo, reshi)); #else return (u64)a * (u64)b + c; #endif } -u128 mad64(u64 a, u64 b, u128 c) { +u64 OVERLOAD mad32(u32 a, u32 b, u64 c) { +#if HAS_PTX // Same speed on TitanV, any gain may be too small to measure + u32 reslo, reshi; + __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" + "madc.hi.u32 %1, %2, %3, %5;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(lo32(c)), "r"(hi32(c))); + return as_ulong((uint2)(reslo, reshi)); +#else + return (u64)a * (u64)b + c; +#endif +} + +u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { +#if 0 && HAS_PTX // Slower on TitanV and mobile 4070, don't understand why + u64 reslo, reshi; + __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" + "madc.hi.u64 %1, %2, %3, 0;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c))); + return make_u128(reshi, reslo); +#elif HAS_PTX // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. + uint2 a2 = as_uint2(a); + uint2 b2 = as_uint2(b); + uint2 c2 = as_uint2(c); + uint2 rlo2, rhi2; + __asm("mad.lo.cc.u32 %0, %4, %6, %8;\n\t" + "madc.hi.cc.u32 %1, %4, %6, %9;\n\t" + "madc.lo.cc.u32 %2, %5, %7, 0;\n\t" + "madc.hi.u32 %3, %5, %7, 0;\n\t" + "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" + "madc.hi.cc.u32 %2, %5, %6, %2;\n\t" + "addc.u32 %3, %3, 0;\n\t" + "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" + "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" + "addc.u32 %3, %3, 0;" + : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "=r"(rhi2.y) + : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(c2.x), "r"(c2.y)); + return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); +#else + return add(mul64(a, b), c); +#endif +} + +u128 OVERLOAD mad64(u64 a, u64 b, u128 c) { #if 0 && HAS_PTX // Slower on TitanV and mobile 4070, don't understand why u64 reslo, reshi; __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" @@ -217,7 +257,7 @@ u128 mad64(u64 a, u64 b, u128 c) { "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" "addc.u32 %3, %3, 0;" - : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "+r"(rhi2.y) + : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "=r"(rhi2.y) : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(clo2.x), "r"(clo2.y), "r"(chi2.x), "r"(chi2.y)); return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); #else @@ -498,10 +538,8 @@ GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } -Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? - // Multiply by 2 -Z31 OVERLOAD mul2(Z31 a) { return ((a + a) + (a >> 30)) & M31; } // GWBUG: Can we do better? +Z31 OVERLOAD mul2(Z31 a) { return add(a, a); } GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } // Return conjugate of a @@ -622,8 +660,6 @@ GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return modM31(add((Z31)(t & M31), (Z31)(t >> 31))); } -Z31 OVERLOAD fma(Z31 a, Z31 b, Z31 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? - // Multiply by 2 Z31 OVERLOAD mul2(Z31 a) { return add(a, a); } GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } @@ -633,37 +669,37 @@ GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } // Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). GF31 OVERLOAD csq(GF31 a) { - u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) - u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 return U2(modM31(r), modM31(i)); } // a^2 + c GF31 OVERLOAD csq_add(GF31 a, GF31 c) { - u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) - u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 - return U2(modM31(r + c.x), modM31(i + c.y)); // GWBUG - hopefully the 64-bit adds are "free" via MAD instructions + u64 r = mad32(a.x + a.y, a.x + neg(a.y), c.x); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, c.y); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); } // a^2 - c GF31 OVERLOAD csq_sub(GF31 a, GF31 c) { - u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) - u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 - return U2(modM31(r + neg(c.x)), modM31((i64) i - c.y)); // GWBUG - check that the compiler generates MAD instructions + u64 r = mad32(a.x + a.y, a.x + neg(a.y), neg(c.x)); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, neg(c.y)); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); } // a^2 + i*c GF31 OVERLOAD csq_addi(GF31 a, GF31 c) { - u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) - u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 - return U2(modM31(r + neg(c.y)), modM31(i + c.x)); // GWBUG - hopefully the 64-bit adds are "free" via MAD instructions + u64 r = mad32(a.x + a.y, a.x + neg(a.y), neg(c.y)); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, c.x); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); } // a^2 - i*c GF31 OVERLOAD csq_subi(GF31 a, GF31 c) { - u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) - u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 - return U2(modM31(r + c.y), modM31((i64) i - c.x)); // GWBUG - check that the compiler generates MAD instructions + u64 r = mad32(a.x + a.y, a.x + neg(a.y), c.y); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, neg(c.x)); // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); } // Complex mul @@ -695,7 +731,7 @@ GF31 OVERLOAD ccubeTrig(GF31 sq, GF31 w) { Z31 tmp = sq.y + sq.y; return U2(modM GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? // mul with (2^15, 2^15). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^15)^2 == 1 (mod M31). -GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } // GWBUG: Can caller use a version that does not negate real? is shl(neg) same as shr??? +GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } // mul with (-2^15, 2^15). (twiddle of 3*tau/8). GF31 OVERLOAD mul_3t8(GF31 a) { return U2(shl(neg(add(a.x, a.y)), 15), shl(sub(a.x, a.y), 15)); } From f25eaceaa812824670047a5951cc12d6f07d4599 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 7 Nov 2025 23:54:30 +0000 Subject: [PATCH 094/115] Improved cache locality for M31+M61 NTTs. Helpful on machines with a good size L2 cache. --- src/Gpu.cpp | 200 +++++++++++++++++++++++++++++----------------------- src/Gpu.h | 23 +++--- 2 files changed, 125 insertions(+), 98 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 6503ef7b..55882872 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -741,16 +741,16 @@ void Gpu::fftW(Buffer& out, Buffer& in) { if (fft.NTT_GF61) kfftWGF61(out, in); } -void Gpu::fftMidIn(Buffer& out, Buffer& in) { - if (fft.FFT_FP64 || fft.FFT_FP32) kfftMidIn(out, in); - if (fft.NTT_GF31) kfftMidInGF31(out, in); - if (fft.NTT_GF61) kfftMidInGF61(out, in); +void Gpu::fftMidIn(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftMidIn(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftMidInGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) kfftMidInGF61(out, in); } -void Gpu::fftMidOut(Buffer& out, Buffer& in) { - if (fft.FFT_FP64 || fft.FFT_FP32) kfftMidOut(out, in); - if (fft.NTT_GF31) kfftMidOutGF31(out, in); - if (fft.NTT_GF61) kfftMidOutGF61(out, in); +void Gpu::fftMidOut(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftMidOut(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftMidOutGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) kfftMidOutGF61(out, in); } void Gpu::fftHin(Buffer& out, Buffer& in) { @@ -759,27 +759,27 @@ void Gpu::fftHin(Buffer& out, Buffer& in) { if (fft.NTT_GF61) kfftHinGF61(out, in); } -void Gpu::tailSquare(Buffer& out, Buffer& in) { +void Gpu::tailSquare(Buffer& out, Buffer& in, int cache_group) { if (!tail_single_kernel) { - if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquareZero(out, in); - if (fft.NTT_GF31) ktailSquareZeroGF31(out, in); - if (fft.NTT_GF61) ktailSquareZeroGF61(out, in); + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailSquareZero(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailSquareZeroGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailSquareZeroGF61(out, in); } - if (fft.FFT_FP64 || fft.FFT_FP32) ktailSquare(out, in); - if (fft.NTT_GF31) ktailSquareGF31(out, in); - if (fft.NTT_GF61) ktailSquareGF61(out, in); + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailSquare(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailSquareGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailSquareGF61(out, in); } -void Gpu::tailMul(Buffer& out, Buffer& in1, Buffer& in2) { - if (fft.FFT_FP64 || fft.FFT_FP32) ktailMul(out, in1, in2); - if (fft.NTT_GF31) ktailMulGF31(out, in1, in2); - if (fft.NTT_GF61) ktailMulGF61(out, in1, in2); +void Gpu::tailMul(Buffer& out, Buffer& in1, Buffer& in2, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailMul(out, in1, in2); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailMulGF31(out, in1, in2); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailMulGF61(out, in1, in2); } -void Gpu::tailMulLow(Buffer& out, Buffer& in1, Buffer& in2) { - if (fft.FFT_FP64 || fft.FFT_FP32) ktailMulLow(out, in1, in2); - if (fft.NTT_GF31) ktailMulLowGF31(out, in1, in2); - if (fft.NTT_GF61) ktailMulLowGF61(out, in1, in2); +void Gpu::tailMulLow(Buffer& out, Buffer& in1, Buffer& in2, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailMulLow(out, in1, in2); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailMulLowGF31(out, in1, in2); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailMulLowGF61(out, in1, in2); } void Gpu::carryA(Buffer& out, Buffer& in) { @@ -939,13 +939,16 @@ vector Gpu::readData() { return readAndCompress(bufData); } // out := inA * inB; inB is preserved void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { fftP(tmp1, ioA); - fftMidIn(tmp2, tmp1); - tailMul(tmp1, inB, tmp2); + + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(tmp2, tmp1, cache_group); + tailMul(tmp1, inB, tmp2, cache_group); + fftMidOut(tmp2, tmp1, cache_group); + } // Register the current ROE pos as multiplication (vs. a squaring) if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } - fftMidOut(tmp2, tmp1); fftW(tmp1, tmp2); if (mul3) { carryM(ioA, tmp1); } else { carryA(ioA, tmp1); } carryB(ioA); @@ -958,15 +961,13 @@ void Gpu::mul(Buffer& io, Buffer& buf1) { // out := inA * inB; void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { - modMul(ioA, true, inB, mul3); + modMul(ioA, LEAD_NONE, inB, mul3); }; -// out := inA * inB; if leadInB set then inB (a.k.a. buf1) is preserved -void Gpu::modMul(Buffer& ioA, bool leadInB, Buffer& inB, bool mul3) { - if (leadInB) { - fftP(buf2, inB); - fftMidIn(buf1, buf2); - } +// out := inA * inB; inB will end up in buf1 in the LEAD_MIDDLE state +void Gpu::modMul(Buffer& ioA, enum LEAD_TYPE leadInB, Buffer& inB, bool mul3) { + if (leadInB == LEAD_NONE) fftP(buf2, inB); + if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf2); mul(ioA, buf1, buf2, buf3, mul3); }; @@ -1127,15 +1128,19 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu while (!testBit(exp, p)) { --p; } for (--p; ; --p) { - fftMidIn(buf2, buf3); - tailSquare(buf3, buf2); - fftMidOut(buf2, buf3); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(buf2, buf3); + tailSquare(buf3, buf2); + fftMidOut(buf2, buf3); + } if (testBit(exp, p)) { doCarry(buf3, buf2); - fftMidIn(buf2, buf3); - tailMulLow(buf3, buf2, buf1); - fftMidOut(buf2, buf3); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(buf2, buf3, cache_group); + tailMulLow(buf3, buf2, buf1, cache_group); + fftMidOut(buf2, buf3, cache_group); + } } if (!p) { break; } @@ -1161,19 +1166,22 @@ void Gpu::doCarry(Buffer& out, Buffer& in) { } } -void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3, bool doLL) { +void Gpu::square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut, bool doMul3, bool doLL) { + // leadOut = LEAD_MIDDLE is not supported (slower than LEAD_WIDTH) + assert(leadOut != LEAD_MIDDLE); // LL does not do Mul3 assert(!(doMul3 && doLL)); - if (leadIn) { - fftP(buf2, in); - fftMidIn(buf1, buf2); - } + if (leadIn == LEAD_NONE) fftP(buf2, in); - tailSquare(buf2, buf1); - fftMidOut(buf1, buf2); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); + tailSquare(buf2, buf1, cache_group); + fftMidOut(buf1, buf2, cache_group); + } - if (leadOut) { + // If leadOut is not allowed then we cannot use the faster carryFused kernel + if (leadOut == LEAD_NONE) { fftW(buf2, buf1); if (!doLL && !doMul3) { carryA(out, buf2); @@ -1183,27 +1191,25 @@ void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, carryM(out, buf2); } carryB(out); - } else { + } + + // Use CarryFused + else { assert(!useLongCarry); assert(!doMul3); - if (doLL) { carryFusedLL(buf2, buf1); } else { carryFused(buf2, buf1); } - // Unused: carryFusedMul(buf2, buf1); - fftMidIn(buf1, buf2); } } -void Gpu::square(Buffer& io) { square(io, io, true, true, false, false); } - u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3) { assert(from < to); - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; for (u32 k = from; k < to; ++k) { - bool leadOut = useLongCarry || (k == to - 1); + enum LEAD_TYPE leadOut = useLongCarry || (k == to - 1) ? LEAD_NONE : LEAD_WIDTH; square(out, (k==from) ? in : out, leadIn, leadOut, doTailMul3 && (k == to - 1)); leadIn = leadOut; } @@ -1501,12 +1507,17 @@ tuple Gpu::measureCarry() { assert(res == state.res64); } - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); ++k; } @@ -1514,20 +1525,20 @@ tuple Gpu::measureCarry() { if (Signal::stopRequested()) { throw "stop requested"; } - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1569,12 +1580,18 @@ tuple Gpu::measureROE(bool quick) { assert(res == state.res64); } - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } @@ -1582,20 +1599,20 @@ tuple Gpu::measureROE(bool quick) { if (Signal::stopRequested()) { throw "stop requested"; } - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1632,12 +1649,18 @@ double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest writeState(state.k, state.check, state.blockSize); assert(dataResidue() == state.res64); - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } queue->finish(); @@ -1645,20 +1668,20 @@ double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest Timer t; queue->setSquareTime(0); // Busy wait on nVidia to get the most accurate timings while tuning - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } queue->finish(); @@ -1726,7 +1749,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool skipNextCheckUpdate = false; u32 persistK = proofSet.next(k); - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; assert(k % blockSize == 0); assert(checkStep % blockSize == 0); @@ -1745,6 +1768,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { skipNextCheckUpdate = false; } else if (k % blockSize == 0) { modMul(bufCheck, leadIn, bufData); + leadIn = LEAD_MIDDLE; } ++k; // !! early inc @@ -1752,7 +1776,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool doStop = (k % blockSize == 0) && (Signal::stopRequested() || (args.iters && k - startK >= args.iters)); bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); bool doLog = k % logStep == 0; - bool leadOut = doCheck || doLog || k == persistK || k == kEnd || useLongCarry; + enum LEAD_TYPE leadOut = doCheck || doLog || k == persistK || k == kEnd || useLongCarry ? LEAD_NONE : LEAD_WIDTH; if (doStop) { log("Stopping, please wait..\n"); } @@ -1886,7 +1910,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { u32 k = startK; u32 kEnd = E - 2; - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; while (true) { ++k; @@ -1898,7 +1922,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { } bool doLog = (k % args.logStep == 0) || doStop; - bool leadOut = doLog || useLongCarry; + enum LEAD_TYPE leadOut = doLog || useLongCarry ? LEAD_NONE : LEAD_WIDTH; squareLL(bufData, leadIn, leadOut); leadIn = leadOut; @@ -1961,7 +1985,7 @@ array Gpu::isCERT(const Task& task) { u32 k = 0; u32 kEnd = task.squarings; - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; while (true) { ++k; @@ -1973,7 +1997,7 @@ array Gpu::isCERT(const Task& task) { } bool doLog = (k % 100'000 == 0) || doStop; - bool leadOut = doLog || useLongCarry; + enum LEAD_TYPE leadOut = doLog || useLongCarry ? LEAD_NONE : LEAD_WIDTH; squareCERT(bufData, leadIn, leadOut); leadIn = leadOut; diff --git a/src/Gpu.h b/src/Gpu.h index 3be64aa4..20bc533f 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -217,14 +217,16 @@ class Gpu { TimeInfo* timeBufVect; ZAvg zAvg; + int NUM_CACHE_GROUPS = 3; + void fftP(Buffer& out, Buffer& in) { fftP(out, reinterpret_cast&>(in)); } void fftP(Buffer& out, Buffer& in); - void fftMidIn(Buffer& out, Buffer& in); - void fftMidOut(Buffer& out, Buffer& in); + void fftMidIn(Buffer& out, Buffer& in, int cache_group = 0); + void fftMidOut(Buffer& out, Buffer& in, int cache_group = 0); void fftHin(Buffer& out, Buffer& in); - void tailSquare(Buffer& out, Buffer& in); - void tailMul(Buffer& out, Buffer& in1, Buffer& in2); - void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2); + void tailSquare(Buffer& out, Buffer& in, int cache_group = 0); + void tailMul(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); + void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); void fftW(Buffer& out, Buffer& in); void carryA(Buffer& out, Buffer& in) { carryA(reinterpret_cast&>(out), in); } void carryA(Buffer& out, Buffer& in); @@ -240,11 +242,12 @@ class Gpu { vector readOut(Buffer &buf); void writeIn(Buffer& buf, vector&& words); - void square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3 = false, bool doLL = false); - void squareCERT(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, false); } - void squareLL(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, true); } + enum LEAD_TYPE {LEAD_NONE = 0, LEAD_WIDTH = 1, LEAD_MIDDLE = 2}; - void square(Buffer& io); + void square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut, bool doMul3 = false, bool doLL = false); + void square(Buffer& io) { square(io, io, LEAD_NONE, LEAD_NONE, false, false); } + void squareCERT(Buffer& io, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut) { square(io, io, leadIn, leadOut, false, false); } + void squareLL(Buffer& io, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut) { square(io, io, leadIn, leadOut, false, true); } u32 squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3); u32 squareLoop(Buffer& io, u32 from, u32 to) { return squareLoop(io, io, from, to, false); } @@ -265,7 +268,7 @@ class Gpu { void mul(Buffer& io, Buffer& inB); void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); - void modMul(Buffer& ioA, bool leadInB, Buffer& inB, bool mul3 = false); + void modMul(Buffer& ioA, enum LEAD_TYPE leadInB, Buffer& inB, bool mul3 = false); fs::path saveProof(const Args& args, const ProofSet& proofSet); std::pair readROE(); From 197928e1dba5dfa02b81e4706933afb3001e2dce Mon Sep 17 00:00:00 2001 From: george Date: Sat, 8 Nov 2025 01:05:01 +0000 Subject: [PATCH 095/115] Made the new modM31 macro a -use option. The new code is 1% faster on an RTX 5080 but 1% slower on a Titan V. --- src/Gpu.cpp | 3 ++- src/cl/base.cl | 3 +++ src/cl/math.cl | 15 +++++++++------ src/tune.cpp | 25 ++++++++++++++++++++++--- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 55882872..3b34e31c 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -274,7 +274,8 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "TABMUL_CHAIN", "TABMUL_CHAIN31", "TABMUL_CHAIN32", - "TABMUL_CHAIN61" + "TABMUL_CHAIN61", + "MODM31" }); if (!isValid) { log("Warning: unrecognized -use key '%s'\n", k.c_str()); diff --git a/src/cl/base.cl b/src/cl/base.cl index 0ebf0fc3..51344cb5 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -128,6 +128,9 @@ G_H "group height" == SMALL_HEIGHT / NH #if !defined(TABMUL_CHAIN61) #define TABMUL_CHAIN61 0 #endif +#if !defined(MODM31) +#define MODM31 0 +#endif #if !defined(MIDDLE_CHAIN) #define MIDDLE_CHAIN 0 diff --git a/src/cl/math.cl b/src/cl/math.cl index 3d6df82c..9354259f 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -600,15 +600,18 @@ GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } #elif 1 // This version is a little sloppy. Returns values in 0..M31 range. // Internal routines to return value in 0..M31 range -//Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +#if MODM31 == 0 +Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#elif MODM31 == 1 Z31 OVERLOAD modM31(Z31 a) { i32 alt = a + 0x80000001; return select32(a, a, alt); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) -//Z31 OVERLOAD modM31(Z31 a) { return optional_add(a, 0x80000001); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) - -//Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) Z31 OVERLOAD modM31(i32 a) { i32 alt = a - 0x80000001; return select32(a, a, alt); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) -//Z31 OVERLOAD modM31(i32 a) { return optional_sub(a, 0x80000001); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#else +Z31 OVERLOAD modM31(Z31 a) { return optional_add(a, 0x80000001); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { return optional_sub(a, 0x80000001); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#endif -Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF7FFFFFFF +Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF7FFFFFFF u32 alo = a & M31; u32 amid = (a >> 31) & M31; u32 ahi = a >> 62; diff --git a/src/tune.cpp b/src/tune.cpp index 4993eb1b..fa6dfb28 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -736,6 +736,27 @@ void Tune::tune() { args->flags["TABMUL_CHAIN61"] = to_string(best_tabmul_chain); } + // Find best MODM31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_modm31 = 0; + u32 current_modm31 = args->value("MODM31", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 modm31 : {0, 1, 2}) { + args->flags["MODM31"] = to_string(modm31); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MODM31=%u is %6.1f\n", fft.spec().c_str(), modm31, cost); + if (modm31 == current_modm31) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_modm31 = modm31; } + } + log("Best MODM31 is %u. Default MODM31 is 0.\n", best_modm31); + configsUpdate(current_cost, best_cost, 0.003, "MODM31", best_modm31, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MODM31"] = to_string(best_modm31); + } + // Find best UNROLL_W setting if (1) { FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; @@ -1046,10 +1067,8 @@ skip_1K_256 = 0; double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); bool isUseful = TuneEntry{cost, fft}.update(results); log("%c %6.1f %12s %9lu\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); + if (isUseful) TuneEntry::writeTuneFile(results); } } } -//GW: write results more often (in case -tune run is aborted)? - - TuneEntry::writeTuneFile(results); } From b6b6452ac52013efeff9fc4ed8c9a9bab71185f0 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 8 Nov 2025 20:11:15 +0000 Subject: [PATCH 096/115] Fixed proof generation/verification bug introduced with M31/M61 cache_grouping. --- src/Gpu.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 3b34e31c..f6183bf8 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1130,9 +1130,9 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu for (--p; ; --p) { for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - fftMidIn(buf2, buf3); - tailSquare(buf3, buf2); - fftMidOut(buf2, buf3); + fftMidIn(buf2, buf3, cache_group); + tailSquare(buf3, buf2, cache_group); + fftMidOut(buf2, buf3, cache_group); } if (testBit(exp, p)) { @@ -1519,6 +1519,7 @@ tuple Gpu::measureCarry() { while (k < warmup) { square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } From 33215a9942946038bcad7e4e1705f69e9dab8278 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 8 Nov 2025 22:50:29 +0000 Subject: [PATCH 097/115] Use mad32 instructions in csqTrig and ccubeTrig. Should help TABMUL_CHAIN31=1 which is rarely set. --- src/cl/fftbase.cl | 18 +++++++++--------- src/cl/math.cl | 27 ++++++++++++++------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index 29ddfa64..bc1fb368 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -605,31 +605,31 @@ void OVERLOAD tabMul(u32 WG, TrigFP32 trig, F2 *u, u32 n, u32 f, u32 me) { void OVERLOAD chainMul4(GF31 *u, GF31 w) { u[1] = cmul(u[1], w); - GF31 base = csq(w); + GF31 base = csqTrig(w); u[2] = cmul(u[2], base); - base = cmul(base, w); //GWBUG - see FP64 version for possible optimization + base = ccubeTrig(base, w); u[3] = cmul(u[3], base); } -void OVERLOAD chainMul8(GF31 *u, GF31 w, u32 tailSquareBcast) { +void OVERLOAD chainMul8(GF31 *u, GF31 w) { u[1] = cmul(u[1], w); - GF31 w2 = csq(w); - u[2] = cmul(u[2], w2); + GF31 base = csqTrig(w); + u[2] = cmul(u[2], base); - GF31 base = cmul (w2, w); //GWBUG - see FP64 version for many possible optimizations + base = ccubeTrig(base, w); for (int i = 3; i < 8; ++i) { u[i] = cmul(u[i], base); base = cmul(base, w); } } -void OVERLOAD chainMul(u32 len, GF31 *u, GF31 w, u32 tailSquareBcast) { +void OVERLOAD chainMul(u32 len, GF31 *u, GF31 w) { // Do a length 4 chain mul if (len == 4) chainMul4(u, w); // Do a length 8 chain mul - if (len == 8) chainMul8(u, w, tailSquareBcast); + if (len == 8) chainMul8(u, w); } void OVERLOAD shuflBigLDS(u32 WG, local GF31 *lds, GF31 *u, u32 n, u32 f) { @@ -689,7 +689,7 @@ void OVERLOAD tabMul(u32 WG, TrigGF31 trig, GF31 *u, u32 n, u32 f, u32 me) { // This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. if (TABMUL_CHAIN31) { - chainMul (n, u, trig[p], 0); + chainMul (n, u, trig[p]); return; } diff --git a/src/cl/math.cl b/src/cl/math.cl index 9354259f..eb238dfc 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -148,8 +148,8 @@ i32 optional_add(i32 a, const i32 b) { #if HAS_PTX __asm("{.reg .pred %%p;\n\t" " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 - " @%%p add.s32 %0, %0, %1;}" // if (a < 0) a = a + b - : "+r"(a) : "n"(b)); + " @%%p add.s32 %0, %0, %1;}" // if (a < 0) a = a + b + : "+r"(a) : "n"(b)); #else if (a < 0) a = a + b; #endif @@ -161,8 +161,8 @@ i32 optional_sub(i32 a, const i32 b) { #if HAS_PTX __asm("{.reg .pred %%p;\n\t" " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 - " @%%p sub.s32 %0, %0, %1;}" // if (a < 0) a = a - b - : "+r"(a) : "n"(b)); + " @%%p sub.s32 %0, %0, %1;}" // if (a < 0) a = a - b + : "+r"(a) : "n"(b)); #else if (a < 0) a = a - b; #endif @@ -174,8 +174,8 @@ i32 optional_mod(i32 a, const i32 b) { #if 0 //HAS_PTX // Not faster on 5xxx GPUs (not sure why) __asm("{.reg .pred %%p;\n\t" " setp.ge.s32 %%p, %0, %1;\n\t" // a > b - " @%%p sub.s32 %0, %0, %1;}" // if (a > b) a = a - b - : "+r"(a) : "n"(b)); + " @%%p sub.s32 %0, %0, %1;}" // if (a > b) a = a - b + : "+r"(a) : "n"(b)); #else if (a >= b) a = a - b; #endif @@ -221,10 +221,10 @@ u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { "madc.hi.cc.u32 %1, %4, %6, %9;\n\t" "madc.lo.cc.u32 %2, %5, %7, 0;\n\t" "madc.hi.u32 %3, %5, %7, 0;\n\t" - "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" + "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" "madc.hi.cc.u32 %2, %5, %6, %2;\n\t" "addc.u32 %3, %3, 0;\n\t" - "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" + "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" "addc.u32 %3, %3, 0;" : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "=r"(rhi2.y) @@ -717,18 +717,19 @@ GF31 OVERLOAD cmul(GF31 a, GF31 b) { } #else GF31 OVERLOAD cmul(GF31 a, GF31 b) { - u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 - u64 k1k2 = mad32(a.x, b.y + neg(b.x), k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 - u64 k1k3 = mad32(neg(a.y), b.y + b.x, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u32 negbx = neg(b.x); // Negate and add b values as much as possible in case b is used several times (as in a chainmul) + u64 k1 = b.x * (u64)(a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 k1k2 = mad32(a.x, b.y + negbx, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u64 k1k3 = mad32(a.y, neg(b.y) + negbx, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 return U2(modM31(k1k3), modM31(k1k2)); } #endif // Square a root of unity complex number -GF31 OVERLOAD csqTrig(GF31 a) { Z31 two_ay = a.y + a.y; return U2(modM31(1 + two_ay * (u64) neg(a.y)), modM31(a.x * (u64) two_ay)); } +GF31 OVERLOAD csqTrig(GF31 a) { u32 two_ay = a.y + a.y; return U2(modM31(mad32(two_ay, neg(a.y), (u32)1)), modM31(a.x * (u64)two_ay)); } // Cube w, a root of unity complex number, given w^2 and w -GF31 OVERLOAD ccubeTrig(GF31 sq, GF31 w) { Z31 tmp = sq.y + sq.y; return U2(modM31(tmp * (u64) neg(w.y) + w.x), modM31(tmp * (u64) w.x + neg(w.y))); } +GF31 OVERLOAD ccubeTrig(GF31 sq, GF31 w) { u32 tmp = sq.y + sq.y; return U2(modM31(mad32(tmp, neg(w.y), w.x)), modM31(mad32(tmp, w.x, neg(w.y)))); } // mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? From c08ad6a2f7292e7ca1ad04809eb38f8c03784d60 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 9 Nov 2025 00:27:08 +0000 Subject: [PATCH 098/115] Added third nontemporal option. I don't know if it will be useful. --- src/cl/base.cl | 6 +++++- src/tune.cpp | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cl/base.cl b/src/cl/base.cl index 51344cb5..c0d34481 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -231,9 +231,13 @@ ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } #define CP(x) const P(x) // Macros for non-temporal load and store (in case we later want to provide a -use option to turn this off) -#if NONTEMPORAL && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) && __has_builtin(__builtin_nontemporal_store) +// The throry behind only non-temporal reads is that kernels may end faster if they can write results to a cache rather than to slow memory. +#if NONTEMPORAL == 1 && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) && __has_builtin(__builtin_nontemporal_store) #define NTLOAD(mem) __builtin_nontemporal_load(&(mem)) #define NTSTORE(mem,val) __builtin_nontemporal_store(val, &(mem)) +#elif NONTEMPORAL == 2 && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) +#define NTLOAD(mem) __builtin_nontemporal_load(&(mem)) +#define NTSTORE(mem,val) (mem) = val #else #define NTLOAD(mem) (mem) #define NTSTORE(mem,val) (mem) = val diff --git a/src/tune.cpp b/src/tune.cpp index fa6dfb28..6427a9b0 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -515,7 +515,7 @@ void Tune::tune() { u32 current_nontemporal = args->value("NONTEMPORAL", 0); double best_cost = -1.0; double current_cost = -1.0; - for (u32 nontemporal : {0, 1}) { + for (u32 nontemporal : {0, 1, 2}) { args->flags["NONTEMPORAL"] = to_string(nontemporal); double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); From 6f4b7aa89b57f9651feece1398737be48a251483 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 9 Nov 2025 04:42:51 +0000 Subject: [PATCH 099/115] Fixed compile-time bug in non-PTX mad64 routine --- src/cl/math.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index eb238dfc..044f131b 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -231,7 +231,7 @@ u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(c2.x), "r"(c2.y)); return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); #else - return add(mul64(a, b), c); + return add(mul64(a, b), make_u128(0, c)); #endif } From 9c6ff472b11b6c8e3b480edc645faa82a1914914 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 9 Nov 2025 05:31:56 +0000 Subject: [PATCH 100/115] Very minor optimization. Add rarely used FFTW kernel to cache_group. --- src/Gpu.cpp | 22 +++++++++++----------- src/Gpu.h | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index f6183bf8..b887a88e 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -736,10 +736,10 @@ void Gpu::fftP(Buffer& out, Buffer& in) { kfftP(out, in); } -void Gpu::fftW(Buffer& out, Buffer& in) { - if (fft.FFT_FP64 || fft.FFT_FP32) kfftW(out, in); - if (fft.NTT_GF31) kfftWGF31(out, in); - if (fft.NTT_GF61) kfftWGF61(out, in); +void Gpu::fftW(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftW(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftWGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) kfftWGF61(out, in); } void Gpu::fftMidIn(Buffer& out, Buffer& in, int cache_group) { @@ -945,12 +945,12 @@ void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buff fftMidIn(tmp2, tmp1, cache_group); tailMul(tmp1, inB, tmp2, cache_group); fftMidOut(tmp2, tmp1, cache_group); + fftW(tmp1, tmp2, cache_group); } // Register the current ROE pos as multiplication (vs. a squaring) if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } - fftW(tmp1, tmp2); if (mul3) { carryM(ioA, tmp1); } else { carryA(ioA, tmp1); } carryB(ioA); } @@ -1175,15 +1175,15 @@ void Gpu::square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enu if (leadIn == LEAD_NONE) fftP(buf2, in); - for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); - tailSquare(buf2, buf1, cache_group); - fftMidOut(buf1, buf2, cache_group); - } + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); + tailSquare(buf2, buf1, cache_group); + fftMidOut(buf1, buf2, cache_group); + if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + } // If leadOut is not allowed then we cannot use the faster carryFused kernel if (leadOut == LEAD_NONE) { - fftW(buf2, buf1); if (!doLL && !doMul3) { carryA(out, buf2); } else if (doLL) { diff --git a/src/Gpu.h b/src/Gpu.h index 20bc533f..86d5b200 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -227,7 +227,7 @@ class Gpu { void tailSquare(Buffer& out, Buffer& in, int cache_group = 0); void tailMul(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); - void fftW(Buffer& out, Buffer& in); + void fftW(Buffer& out, Buffer& in, int cache_group = 0); void carryA(Buffer& out, Buffer& in) { carryA(reinterpret_cast&>(out), in); } void carryA(Buffer& out, Buffer& in); void carryM(Buffer& out, Buffer& in); From 9076198693e3a16068c05f061ad584aa2b39b53d Mon Sep 17 00:00:00 2001 From: george Date: Mon, 10 Nov 2025 19:08:02 +0000 Subject: [PATCH 101/115] Check for needed builtins in variant 0 FFTs --- src/cl/base.cl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/cl/base.cl b/src/cl/base.cl index c0d34481..55a3ff73 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -111,6 +111,22 @@ G_H "group height" == SMALL_HEIGHT / NH #if FFT_VARIANT_H > 2 #error FFT_VARIANT_H must be between 0 and 2 #endif +// C code ensures that only AMD GPUs use FFT_VARIANT_W=0 and FFT_VARIANT_H=0. However, this does not guarantee that the OpenCL compiler supports +// the necessary amdgcn builtins. If those builtins are not present convert to variant one. +#if AMDGPU +#if !defined(__has_builtin) || !__has_builtin(__builtin_amdgcn_mov_dpp) || !__has_builtin(__builtin_amdgcn_ds_swizzle) || !__has_builtin(__builtin_amdgcn_readfirstlane) +#if FFT_VARIANT_W == 0 +#warning Missing builtins for FFT_VARIANT_W=0, switching to FFT_VARIANT_W=1 +#undef FFT_VARIANT_W +#define FFT_VARIANT_W 1 +#endif +#if FFT_VARIANT_H == 0 +#warning Missing builtins for FFT_VARIANT_H=0, switching to FFT_VARIANT_H=1 +#undef FFT_VARIANT_H +#define FFT_VARIANT_H 1 +#endif +#endif +#endif #if !defined(BIGLIT) #define BIGLIT 1 From a7bb209a4e2a20f0c1fbcf797e7f3d71d4a40525 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 15 Nov 2025 18:03:31 +0000 Subject: [PATCH 102/115] Wrote asm routines for prefetching on nVidia. They did not help. Maybe I'll figure out a may to use them profitably in the future. --- src/cl/base.cl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/cl/base.cl b/src/cl/base.cl index 55a3ff73..34fa7159 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -246,8 +246,8 @@ ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } #define P(x) global x * restrict #define CP(x) const P(x) -// Macros for non-temporal load and store (in case we later want to provide a -use option to turn this off) -// The throry behind only non-temporal reads is that kernels may end faster if they can write results to a cache rather than to slow memory. +// Macros for non-temporal load and store. The theory behind only non-temporal reads (option 2) is that with alternating buffers, +// read buffers will not be needed for quite a while, but write buffers will be needed soon. #if NONTEMPORAL == 1 && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) && __has_builtin(__builtin_nontemporal_store) #define NTLOAD(mem) __builtin_nontemporal_load(&(mem)) #define NTSTORE(mem,val) __builtin_nontemporal_store(val, &(mem)) @@ -259,6 +259,18 @@ ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } #define NTSTORE(mem,val) (mem) = val #endif +// Prefetch macros. Unused at present, I tried using them in fftMiddleInGF61 on a 5080 with no benefit. +void PREFETCHL1(const __global void *addr) { +#if HAS_PTX + __asm("prefetch.global.L1 [%0];" : : "l"(addr)); +#endif +} +void PREFETCHL2(const __global void *addr) { +#if HAS_PTX + __asm("prefetch.global.L2 [%0];" : : "l"(addr)); +#endif +} + // For reasons unknown, loading trig values into nVidia's constant cache has terrible performance #if AMDGPU typedef constant const T2* Trig; From 5e6ad1b9765bb072a5a632335023a968000812a3 Mon Sep 17 00:00:00 2001 From: george Date: Sat, 15 Nov 2025 19:23:57 +0000 Subject: [PATCH 103/115] Fixed lint issues in MINGW64 where u64 is unsigned long long rather than unsigned long --- src/FFTConfig.cpp | 3 ++- src/File.h | 8 +++++--- src/Gpu.cpp | 9 ++++++--- src/tune.cpp | 3 ++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 85fd02bb..2308a037 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -14,6 +14,7 @@ #include #include #include +#include using namespace std; @@ -275,7 +276,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { if (!spec.empty()) { FFTConfig fft{spec}; if (fft.maxExp() * args.fftOverdrive < E) { - log("Warning: %s (max %lu) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %" PRIu64 ") may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } return fft; } diff --git a/src/File.h b/src/File.h index 8419f9e1..1e1766e4 100644 --- a/src/File.h +++ b/src/File.h @@ -60,15 +60,17 @@ class File { _commit(fileno(f)); #elif defined(__APPLE__) fcntl(fileno(f), F_FULLFSYNC, 0); -#elif defined(__MSYS__) +#elif defined(__MINGW32__) || defined(__MINGW64__) + fdatasync(fileno(f)); +#elif defined(__MSYS__) // MSYS2 using CLANG64 compiler #define fileno(__F) ((__F)->_file) - fsync(fileno(f)); + fsync(fileno(f)); // This doesn't work #undef fileno #else fdatasync(fileno(f)); #endif } - + public: const std::string name; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index b887a88e..d19c0c5f 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #define _USE_MATH_DEFINES #include @@ -162,8 +163,10 @@ Weights genWeights(FFTConfig fft, u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { string toLiteral(i32 value) { return to_string(value); } string toLiteral(u32 value) { return to_string(value) + 'u'; } -[[maybe_unused]] string toLiteral(i64 value) { return to_string(value) + "l"; } -[[maybe_unused]] string toLiteral(u64 value) { return to_string(value) + "ul"; } +[[maybe_unused]] string toLiteral(long value) { return to_string(value) + "l"; } +[[maybe_unused]] string toLiteral(unsigned long value) { return to_string(value) + "ul"; } +[[maybe_unused]] string toLiteral(long long value) { return to_string(value) + "ll"; } +[[maybe_unused]] string toLiteral(unsigned long long value) { return to_string(value) + "ull"; } template string toLiteral(F value) { @@ -644,7 +647,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& // Sometimes we do want to run a FFT beyond a reasonable BPW (e.g. during -ztune), and these situations // coincide with logFftSize == false if (fft.maxExp() < E) { - log("Warning: %s (max %lu) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %" PRIu64 ") may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } } diff --git a/src/tune.cpp b/src/tune.cpp index 6427a9b0..5378a319 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -14,6 +14,7 @@ #include #include #include +#include using std::accumulate; @@ -1066,7 +1067,7 @@ skip_1K_256 = 0; double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); bool isUseful = TuneEntry{cost, fft}.update(results); - log("%c %6.1f %12s %9lu\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); + log("%c %6.1f %12s %9" PRIu64 "\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); if (isUseful) TuneEntry::writeTuneFile(results); } } From fa249e175ef5292cb8cbf06e050b38c33a468d7f Mon Sep 17 00:00:00 2001 From: george Date: Sun, 16 Nov 2025 02:30:50 +0000 Subject: [PATCH 104/115] More MINGW-64 changes where a long is 32 bits vs. everywhere else a long is 64 bits. --- Makefile | 5 ++++- src/Gpu.cpp | 9 ++++----- src/common.h | 12 ++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index e9c62f34..26bffeda 100644 --- a/Makefile +++ b/Makefile @@ -19,9 +19,12 @@ else CXX = g++ endif +ifneq ($(findstring MINGW, $(HOST_OS)), MINGW) COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc +else # For mingw-64 use this: -#COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc -static +COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc -static +endif # -fext-numeric-literals ifeq ($(HOST_OS), Darwin) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index d19c0c5f..c943b08f 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -165,8 +165,8 @@ string toLiteral(i32 value) { return to_string(value); } string toLiteral(u32 value) { return to_string(value) + 'u'; } [[maybe_unused]] string toLiteral(long value) { return to_string(value) + "l"; } [[maybe_unused]] string toLiteral(unsigned long value) { return to_string(value) + "ul"; } -[[maybe_unused]] string toLiteral(long long value) { return to_string(value) + "ll"; } -[[maybe_unused]] string toLiteral(unsigned long long value) { return to_string(value) + "ull"; } +[[maybe_unused]] string toLiteral(long long value) { return to_string(value) + "l"; } // Yes, this looks wrong. The Mingw64 C compiler uses +[[maybe_unused]] string toLiteral(unsigned long long value) { return to_string(value) + "ul"; } // long long for 64-bits, while openCL uses long for 64 bits. template string toLiteral(F value) { @@ -209,7 +209,6 @@ string toLiteral(const string& s) { return s; } [[maybe_unused]] string toLiteral(float2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } [[maybe_unused]] string toLiteral(double2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } [[maybe_unused]] string toLiteral(int2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } -[[maybe_unused]] string toLiteral(long2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } [[maybe_unused]] string toLiteral(uint2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } [[maybe_unused]] string toLiteral(ulong2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } @@ -1143,8 +1142,8 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { fftMidIn(buf2, buf3, cache_group); tailMulLow(buf3, buf2, buf1, cache_group); - fftMidOut(buf2, buf3, cache_group); - } + fftMidOut(buf2, buf3, cache_group); + } } if (!p) { break; } diff --git a/src/common.h b/src/common.h index 77784b05..516b099d 100644 --- a/src/common.h +++ b/src/common.h @@ -27,14 +27,14 @@ namespace fs = std::filesystem; // C code will use i64 integer data. The code that reads and writes GPU buffers will downsize the integers to 32 bits when required. typedef i64 Word; +// Create datatype names that mimic the ones used in OpenCL code using double2 = pair; using float2 = pair; -using int2 = pair; -using long2 = pair; -using uint = unsigned int; -using uint2 = pair; -using ulong = unsigned long; -using ulong2 = pair; +using int2 = pair; +using uint = u32; +using uint2 = pair; +using ulong = u64; +using ulong2 = pair; std::vector split(const string& s, char delim); From 355490982c48907b939849493996c3f5af574ead Mon Sep 17 00:00:00 2001 From: george Date: Sun, 16 Nov 2025 03:26:39 +0000 Subject: [PATCH 105/115] Changed all usages of HAS_PTX to test for needed level of CUDA support. Now we just need to automatically detect the GPU's actual level of CUDA support. --- src/cl/base.cl | 6 +++--- src/cl/carryutil.cl | 12 ++++++------ src/cl/math.cl | 24 ++++++++++++------------ 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/cl/base.cl b/src/cl/base.cl index 34fa7159..df1ef02b 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -66,7 +66,7 @@ G_H "group height" == SMALL_HEIGHT / NH #define HAS_PTX 0 #elif NVIDIAGPU #define HAS_ASM 0 -#define HAS_PTX 1 +#define HAS_PTX 1200 // Assume CUDA 12.00 support until we can figure out how to automatically determine this at runtime #else #define HAS_ASM 0 #define HAS_PTX 0 @@ -261,12 +261,12 @@ ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } // Prefetch macros. Unused at present, I tried using them in fftMiddleInGF61 on a 5080 with no benefit. void PREFETCHL1(const __global void *addr) { -#if HAS_PTX +#if HAS_PTX >= 200 // Prefetch instruction requires sm_20 support or higher __asm("prefetch.global.L1 [%0];" : : "l"(addr)); #endif } void PREFETCHL2(const __global void *addr) { -#if HAS_PTX +#if HAS_PTX >= 200 // Prefetch instruction requires sm_20 support or higher __asm("prefetch.global.L2 [%0];" : : "l"(addr)); #endif } diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index 9a78290b..6cdff372 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -24,7 +24,7 @@ typedef i32 CarryABM; // Return unsigned low bits (number of bits must be between 1 and 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) u32 OVERLOAD ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } -#elif HAS_PTX +#elif HAS_PTX >= 700 // szext instruction requires sm_70 support or higher u32 OVERLOAD ulowBits(i32 u, u32 bits) { u32 res; __asm("szext.clamp.u32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else u32 OVERLOAD ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } @@ -44,7 +44,7 @@ u64 OVERLOAD ulowFixedBits(u64 u, const u32 bits) { return ulowFixedBits((i64) u // Return signed low bits (number of bits must be between 1 and 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) i32 OVERLOAD lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } -#elif HAS_PTX +#elif HAS_PTX >= 700 // szext instruction requires sm_70 support or higher i32 OVERLOAD lowBits(i32 u, u32 bits) { i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else i32 OVERLOAD lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } @@ -55,7 +55,7 @@ i64 OVERLOAD lowBits(i64 u, u32 bits) { return ((u << (64 - bits)) >> (64 - bits i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64)u, bits); } // Return signed low bits (number of bits must be between 1 and 32) -#if HAS_PTX +#if HAS_PTX // szext does not return result we are looking for if bits = 32 i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } #else i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits((u64)u, bits); } @@ -65,7 +65,7 @@ i32 OVERLOAD lowBitsSafe32(u32 u, u32 bits) { return lowBitsSafe32((i32)u, bits) // Return signed low bits where number of bits is known at compile time (number of bits can be 0 to 32) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_sbfe(u, 0, bits); } -#elif HAS_PTX +#elif HAS_PTX >= 700 // szext instruction requires sm_70 support or higher i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return (u << (32 - bits)) >> (32 - bits); } @@ -79,14 +79,14 @@ i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64)u, b // Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, as_int2(x).x, bits); } -#elif HAS_PTX +#elif HAS_PTX >= 320 // shf instruction requires sm_32 support or higher i32 xtract32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } #else i32 xtract32(i64 x, u32 bits) { return x >> bits; } #endif // Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 32) -#if HAS_PTX +#if HAS_PTX >= 320 // shf instruction requires sm_32 support or higher i32 xtractSafe32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } #else i32 xtractSafe32(i64 x, u32 bits) { return x >> bits; } diff --git a/src/cl/math.cl b/src/cl/math.cl index 044f131b..513d8832 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -27,7 +27,7 @@ u64 i96_lo64(i96 val) { return as_ulong((uint2)(val.lo32, val.mid32)); } u64 i96_hi64(i96 val) { return as_ulong((uint2)(val.mid32, val.hi32)); } i96 OVERLOAD add(i96 a, i96 b) { i96 val; -#if HAS_PTX +#if HAS_PTX >= 200 // Carry instructions requires sm_20 support or higher __asm("add.cc.u32 %0, %3, %6;\n\t" "addc.cc.u32 %1, %4, %7;\n\t" "addc.u32 %2, %5, %8;" @@ -42,7 +42,7 @@ i96 OVERLOAD add(i96 a, i96 b) { i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } i96 OVERLOAD sub(i96 a, i96 b) { i96 val; -#if HAS_PTX +#if HAS_PTX >= 200 // Carry instructions requires sm_20 support or higher __asm("sub.cc.u32 %0, %3, %6;\n\t" "subc.cc.u32 %1, %4, %7;\n\t" "subc.u32 %2, %5, %8;" @@ -134,7 +134,7 @@ u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.lo64 = a.lo64 + b.lo64; val.hi // Select based on sign of first argument. This generates less PTX code, but is no faster on 5xxx GPUs i32 select32(i32 a, i32 b, i32 c) { -#if HAS_PTX +#if HAS_PTX >= 100 // slct instruction requires sm_10 support or higher i32 res; __asm("slct.s32.s32 %0, %2, %3, %1;" : "=r"(res) : "r"(a), "r"(b), "r"(c)); return res; @@ -145,7 +145,7 @@ i32 select32(i32 a, i32 b, i32 c) { // Optionally add a value if first arg is negative. i32 optional_add(i32 a, const i32 b) { -#if HAS_PTX +#if HAS_PTX >= 100 // setp/add instruction requires sm_10 support or higher __asm("{.reg .pred %%p;\n\t" " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 " @%%p add.s32 %0, %0, %1;}" // if (a < 0) a = a + b @@ -158,7 +158,7 @@ i32 optional_add(i32 a, const i32 b) { // Optionally subtract a value if first arg is negative. i32 optional_sub(i32 a, const i32 b) { -#if HAS_PTX +#if HAS_PTX >= 100 // setp/sub instruction requires sm_10 support or higher __asm("{.reg .pred %%p;\n\t" " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 " @%%p sub.s32 %0, %0, %1;}" // if (a < 0) a = a - b @@ -171,7 +171,7 @@ i32 optional_sub(i32 a, const i32 b) { // Optionally subtract a value if first arg is greater than value. i32 optional_mod(i32 a, const i32 b) { -#if 0 //HAS_PTX // Not faster on 5xxx GPUs (not sure why) +#if 0 //HAS_PTX >= 100 // setp/sub instruction requires sm_10 support or higher // Not faster on 5xxx GPUs (not sure why) __asm("{.reg .pred %%p;\n\t" " setp.ge.s32 %%p, %0, %1;\n\t" // a > b " @%%p sub.s32 %0, %0, %1;}" // if (a > b) a = a - b @@ -185,7 +185,7 @@ i32 optional_mod(i32 a, const i32 b) { // Multiply and add primitives u64 OVERLOAD mad32(u32 a, u32 b, u32 c) { -#if HAS_PTX // Same speed on TitanV, any gain may be too small to measure +#if HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Same speed on TitanV, any gain may be too small to measure u32 reslo, reshi; __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" "madc.hi.u32 %1, %2, %3, 0;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(c)); @@ -196,7 +196,7 @@ u64 OVERLOAD mad32(u32 a, u32 b, u32 c) { } u64 OVERLOAD mad32(u32 a, u32 b, u64 c) { -#if HAS_PTX // Same speed on TitanV, any gain may be too small to measure +#if HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Same speed on TitanV, any gain may be too small to measure u32 reslo, reshi; __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" "madc.hi.u32 %1, %2, %3, %5;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(lo32(c)), "r"(hi32(c))); @@ -207,12 +207,12 @@ u64 OVERLOAD mad32(u32 a, u32 b, u64 c) { } u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { -#if 0 && HAS_PTX // Slower on TitanV and mobile 4070, don't understand why +#if 0 && HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Slower on TitanV and mobile 4070, don't understand why u64 reslo, reshi; __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" "madc.hi.u64 %1, %2, %3, 0;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c))); return make_u128(reshi, reslo); -#elif HAS_PTX // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. +#elif HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. uint2 a2 = as_uint2(a); uint2 b2 = as_uint2(b); uint2 c2 = as_uint2(c); @@ -236,12 +236,12 @@ u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { } u128 OVERLOAD mad64(u64 a, u64 b, u128 c) { -#if 0 && HAS_PTX // Slower on TitanV and mobile 4070, don't understand why +#if 0 && HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Slower on TitanV and mobile 4070, don't understand why u64 reslo, reshi; __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" "madc.hi.u64 %1, %2, %3, %5;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c)), "l"(u128_hi64(c))); return make_u128(reshi, reslo); -#elif HAS_PTX // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. +#elif HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. uint2 a2 = as_uint2(a); uint2 b2 = as_uint2(b); uint2 clo2 = as_uint2(u128_lo64(c)); From d2d52aa08165971a4997b4d684b0fc0397a14a91 Mon Sep 17 00:00:00 2001 From: george Date: Sun, 16 Nov 2025 23:16:00 +0000 Subject: [PATCH 106/115] Fixed memory leak enqueuing marker --- src/Queue.cpp | 4 ++-- src/Queue.h | 2 +- src/clwrap.cpp | 6 ++++++ src/clwrap.h | 2 ++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/Queue.cpp b/src/Queue.cpp index 1dd8f011..27f461ad 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -92,7 +92,7 @@ void Queue::queueMarkerEvent() { } // Enqueue a marker for nVidia GPUs else { - clEnqueueMarkerWithWaitList(get(), 0, NULL, &markerEvent); + markerEvent = enqueueMarker(get()); markerQueued = true; queueCount = 0; } @@ -103,7 +103,7 @@ void Queue::waitForMarkerEvent() { if (!markerQueued) return; // By default, nVidia finish causes a CPU busy wait. Instead, sleep for a while. Since we know how many items are enqueued after the marker we can make an // educated guess of how long to sleep to keep CPU overhead low. - while (getEventInfo(markerEvent) != CL_COMPLETE) { + while (getEventInfo(markerEvent.get()) != CL_COMPLETE) { // There are 4, 7, or 10 kernels per squaring. Don't overestimate sleep time. Divide by much more than the number of kernels. std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / squareKernels / 2)); } diff --git a/src/Queue.h b/src/Queue.h index 4ed4200a..be8341d6 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -55,7 +55,7 @@ class Queue : public QueueHolder { private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. int MAX_QUEUE_COUNT; // Queue size before a marker will be enqueued. Typically, 100 to 1000 squarings. - cl_event markerEvent; // Event associated with an enqueued marker placed in the queue every MAX_QUEUE_COUNT entries and before r/w operations. + EventHolder markerEvent; // Event associated with an enqueued marker placed in the queue every MAX_QUEUE_COUNT entries and before r/w operations. bool markerQueued; // TRUE if a marker and event have been queued int queueCount; // Count of items added to the queue since last marker int squareTime; // Time to do one squaring (in microseconds) diff --git a/src/clwrap.cpp b/src/clwrap.cpp index 46e078f0..93385e64 100644 --- a/src/clwrap.cpp +++ b/src/clwrap.cpp @@ -363,6 +363,12 @@ EventHolder fillBuf(cl_queue q, vector&& waits, return genEvent ? EventHolder{event} : EventHolder{}; } +EventHolder enqueueMarker(cl_queue q) { + cl_event event{}; + CHECK1(clEnqueueMarkerWithWaitList(q, 0, 0, &event)); + return EventHolder{event}; +} + void waitForEvents(vector&& waits) { if (!waits.empty()) { CHECK1(clWaitForEvents(waits.size(), waits.data())); diff --git a/src/clwrap.h b/src/clwrap.h index e67011b4..8b9bf9e4 100644 --- a/src/clwrap.h +++ b/src/clwrap.h @@ -106,6 +106,8 @@ EventHolder copyBuf(cl_queue queue, vector&& waits, const cl_mem src, EventHolder fillBuf(cl_queue q, vector&& waits, cl_mem buf, const void *pat, size_t patSize, size_t size, bool genEvent); +EventHolder enqueueMarker(cl_queue q); + void waitForEvents(vector&& waits); From 76e427635b5e78d279eb92bb525c36ca9c2e4e82 Mon Sep 17 00:00:00 2001 From: george Date: Tue, 18 Nov 2025 18:14:44 +0000 Subject: [PATCH 107/115] Fix for compile error on Windows AMD GPUs trying to access amdgcn builtins. --- src/cl/fftbase.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index bc1fb368..c955e803 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -87,7 +87,7 @@ void OVERLOAD chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { } -#if AMDGPU +#if AMDGPU && (FFT_VARIANT_W == 0 || FFT_VARIANT_H == 0) int bcast4(int x) { return __builtin_amdgcn_mov_dpp(x, 0, 0xf, 0xf, false); } int bcast8(int x) { return __builtin_amdgcn_ds_swizzle(x, 0x0018); } From f3326d352908db33759997231712161953584c9b Mon Sep 17 00:00:00 2001 From: george Date: Tue, 18 Nov 2025 18:39:09 +0000 Subject: [PATCH 108/115] Close .cert file before trying to delete the file. Windows file delete fails if file is open. --- src/Gpu.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/Gpu.cpp b/src/Gpu.cpp index c943b08f..3a3ab0bf 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -1970,14 +1970,15 @@ array Gpu::isCERT(const Task& task) { // Get CERT start value char fname[32]; sprintf(fname, "M%u.cert", E); - File fi = File::openReadThrow(fname); -//We need to gracefully handle the CERT file missing. There is a window in primenet.py between worktodo.txt entry and starting value download. +// Autoprimenet.py does not add the cert entry to worktodo.txt until it has successfully downloaded the .cert file. - u32 nBytes = (E - 1) / 8 + 1; - Words B = fi.readBytesLE(nBytes); - - writeIn(bufData, std::move(B)); + { // Enclosing this code in braces ensures the file will be closed by the File destructor. The later file deletion requires the file be closed in Windows. + File fi = File::openReadThrow(fname); + u32 nBytes = (E - 1) / 8 + 1; + Words B = fi.readBytesLE(nBytes); + writeIn(bufData, std::move(B)); + } Timer elapsedTimer; From 553ee7cdb97eaa3801a749e95fcbdd97dde0d41c Mon Sep 17 00:00:00 2001 From: george Date: Wed, 19 Nov 2025 18:42:30 +0000 Subject: [PATCH 109/115] Fixed compile problem on FP32+GF61 hybrid FFTs. --- src/cl/math.cl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index 513d8832..a823181f 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -316,13 +316,13 @@ void OVERLOAD cmul_a_by_b_and_conjb(T2 *res1, T2 *res2, T2 a, T2 b) { } // Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. -T2 csqTrig(T2 a) { T two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } +T2 OVERLOAD csqTrig(T2 a) { T two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } T2 csqTrigFancy(T2 a) { T two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } T2 csqTrigDefancy(T2 a) { T two_ay = mul2(a.y); return U2(fma (-two_ay, a.y, 1), fma(a.x, two_ay, two_ay)); } // Cube a complex number w (cos,sin) given w^2 and w. The squared input can be either fancy or not fancy. // Fancy cCube takes a fancy w argument and returns a fancy value. Defancy takes a fancy w argument and returns a non-fancy value. -T2 ccubeTrig(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } +T2 OVERLOAD ccubeTrig(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } T2 ccubeTrigFancy(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } T2 ccubeTrigDefancy(T2 sq, T2 w) { T tmp = mul2(sq.y); T wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } @@ -430,13 +430,13 @@ void OVERLOAD cmul_a_by_b_and_conjb(F2 *res1, F2 *res2, F2 a, F2 b) { } // Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. -F2 csqTrig(F2 a) { F two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } +F2 OVERLOAD csqTrig(F2 a) { F two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } F2 csqTrigFancy(F2 a) { F two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } F2 csqTrigDefancy(F2 a) { F two_ay = mul2(a.y); return U2(fma (-two_ay, a.y, 1), fma(a.x, two_ay, two_ay)); } // Cube a complex number w (cos,sin) given w^2 and w. The squared input can be either fancy or not fancy. // Fancy cCube takes a fancy w argument and returns a fancy value. Defancy takes a fancy w argument and returns a non-fancy value. -F2 ccubeTrig(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } +F2 OVERLOAD ccubeTrig(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } F2 ccubeTrigFancy(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } F2 ccubeTrigDefancy(F2 sq, F2 w) { F tmp = mul2(sq.y); F wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } From c923698bd7a2358528dc636d7bb5b29a0bffdcb3 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 19 Nov 2025 18:46:02 +0000 Subject: [PATCH 110/115] Fixed two more hybrid FFT OVERLOAD errors --- src/cl/math.cl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cl/math.cl b/src/cl/math.cl index a823181f..479238c6 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -379,10 +379,10 @@ T2 OVERLOAD addsub(T2 a) { return U2(a.x + a.y, a.x - a.y); } // computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) // which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 -T2 foo2(T2 a, T2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } +T2 OVERLOAD foo2(T2 a, T2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } // computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) -T2 foo(T2 a) { return foo2(a, a); } +T2 OVERLOAD foo(T2 a) { return foo2(a, a); } #endif @@ -493,10 +493,10 @@ F2 OVERLOAD addsub(F2 a) { return U2(a.x + a.y, a.x - a.y); } // computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) // which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 -F2 foo2(F2 a, F2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } +F2 OVERLOAD foo2(F2 a, F2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } // computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) -F2 foo(F2 a) { return foo2(a, a); } +F2 OVERLOAD foo(F2 a) { return foo2(a, a); } #endif From 6bc8b434f6045d8a7caea5f05627ca7dac34ca47 Mon Sep 17 00:00:00 2001 From: george Date: Wed, 19 Nov 2025 19:11:26 +0000 Subject: [PATCH 111/115] -tune spews warning messages on Windows with an AMD GPU because clang does not support the builtins required for variant zero. As a poor workaround, this change let's the user specify NO_ASM to bypass tuning FP64 variant zero. --- src/tune.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index 5378a319..15eedef2 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -971,16 +971,18 @@ skip_1K_256 = 0; // Only FP64 code supports variants if (variant != 202 && !FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) continue; - // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. + // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. CLANG doesn't support builtins. Let NO_ASM bypass variant zero. if (variant_W(variant) == 0) { if (!AMDGPU) continue; if (shape.width > 1024) continue; + if (args->value("NO_ASM", 0)) continue; } // Only AMD GPUs support variant zero (BCAST) and only if height <= 1024. if (variant_H(variant) == 0) { if (!AMDGPU) continue; if (shape.height > 1024) continue; + if (args->value("NO_ASM", 0)) continue; } // Reject shapes that won't be used to test exponents in the user's desired range From c363095a8cbac0186eb755f32afe6ca1e2b58bac Mon Sep 17 00:00:00 2001 From: george Date: Wed, 10 Dec 2025 02:52:39 +0000 Subject: [PATCH 112/115] Added INPLACE=1 option. Documentation says INPLACE=2 will choose a good AMD memory layout. I have been unable to find a memory layout nearly as good as the INPLACE=0 layout. --- src/Args.cpp | 12 +- src/Gpu.cpp | 182 +++++++++++++------ src/Gpu.h | 25 +-- src/cl/fft-middle.cl | 60 +++++++ src/cl/fftheight.cl | 4 + src/cl/fftmiddlein.cl | 198 ++++++++++++++++++++ src/cl/fftmiddleout.cl | 207 ++++++++++++++++++++- src/cl/middle.cl | 399 +++++++++++++++++++++++++++++++++++++++++ src/tune.cpp | 105 ++++++----- 9 files changed, 1074 insertions(+), 118 deletions(-) diff --git a/src/Args.cpp b/src/Args.cpp index 4448d8ec..041202ad 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -186,6 +186,8 @@ named "config.txt" in the prpll run directory. 2 = calculate from scratch, no memory read 1 = calculate using one complex multiply from cached memory and uncached memory 0 = read trig values from memory + -use INPLACE=n : Perform tranforms in-place. Great if the reduced memory usage fits in the GPU's L2 cache. + 0 = not in-place, 1 = nVidia friendly access pattern, 2 = AMD friendly access pattern. -use PAD= : insert pad bytes to possibly improve memory access patterns. Val is number bytes to pad. -use MIDDLE_IN_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory -use MIDDLE_OUT_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory @@ -198,11 +200,11 @@ named "config.txt" in the prpll run directory. -tune : Looks for best settings to include in config.txt. Times many FFTs to find fastest one to test exponents -- written to tune.txt. An -fft can be given on the command line to limit which FFTs are timed. Options are not required. If present, the options are a comma separated list from below. - noconfig - Skip timings to find best config.txt settings - fp64 - Tune for settings that affect FP64 FFTs. Time FP64 FFTs for tune.txt. - ntt - Tune for settings that affect integer NTTs. Time integer NTTs for tune.txt. - minexp= - Time FFTs to find the best one for exponents greater than . - maxexp= - Time FFTs to find the best one for exponents less than . + noconfig - Skip timings to find best config.txt settings + fp64 - Tune for settings that affect FP64 FFTs. Time FP64 FFTs for tune.txt. + ntt - Tune for settings that affect integer NTTs. Time integer NTTs for tune.txt. + minexp= - Time FFTs to find the best one for exponents greater than . + maxexp= - Time FFTs to find the best one for exponents less than . -device : select the GPU at position N in the list of devices -uid : select the GPU with the given UID (on ROCm/AMDGPU, Linux) -pci : select the GPU with the given PCI BDF, e.g. "0c:00.0" diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 3a3ab0bf..44656fa1 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -228,7 +228,7 @@ constexpr bool isInList(const string& s, initializer_list list) { } string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector& extraConf, u32 E, bool doLog, - bool &tail_single_wide, bool &tail_single_kernel, u32 &pad_size) { + bool &tail_single_wide, bool &tail_single_kernel, u32 &in_place, u32 &pad_size) { map config; // Highest priority is the requested "extra" conf @@ -245,6 +245,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< // Default value for -use options that must also be parsed in C++ code tail_single_wide = 0, tail_single_kernel = 1; // Default tailSquare is double-wide in one kernel + in_place = 0; // Default is not in-place pad_size = isAmdGpu(id) ? 256 : 0; // Default is 256 bytes for AMD, 0 for others // Validate -use options @@ -265,6 +266,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "CARRY64", "BIGLIT", "NONTEMPORAL", + "INPLACE", "PAD", "MIDDLE_IN_LDS_TRANSPOSE", "MIDDLE_OUT_LDS_TRANSPOSE", @@ -290,6 +292,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< if (atoi(v.c_str()) == 2) tail_single_wide = 0, tail_single_kernel = 1; if (atoi(v.c_str()) == 3) tail_single_wide = 0, tail_single_kernel = 0; } + if (k == "INPLACE") in_place = atoi(v.c_str()); if (k == "PAD") pad_size = atoi(v.c_str()); } @@ -348,7 +351,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< // The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". if (fft.FFT_FP64 && fft.NTT_GF31) { // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. - defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); @@ -356,25 +359,25 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< else if (fft.FFT_FP32 && fft.NTT_GF31 && fft.NTT_GF61) { // GF31 and GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. u32 sz1, sz2, sz3, sz4; - defines += toDefine("DISTGF31", sz1 = FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF31", sz1 = FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF31", sz2 = SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF31", sz3 = MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF31", sz4 = SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); - defines += toDefine("DISTGF61", sz1 + GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF61", sz1 + GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF61", sz2 + SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF61", sz3 + MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF61", sz4 + SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } else if (fft.FFT_FP32 && fft.NTT_GF31) { // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); } else if (fft.FFT_FP32 && fft.NTT_GF61) { // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. - defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); @@ -385,7 +388,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< defines += toDefine("DISTMTRIGGF31", 0); defines += toDefine("DISTHTRIGGF31", 0); // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. - defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, pad_size) / 2); + defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); @@ -529,7 +532,7 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& nW(fft.shape.nW()), nH(fft.shape.nH()), useLongCarry{args.carry == Args::CARRY_LONG}, - compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, tail_single_wide, tail_single_kernel, pad_size)}, + compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, tail_single_wide, tail_single_kernel, in_place, pad_size)}, #define K(name, ...) name(#name, &compiler, profile.make(#name), queue, __VA_ARGS__) @@ -630,9 +633,9 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& BUF(bufROE, ROE_SIZE), BUF(bufStatsCarry, CARRY_SIZE), - BUF(buf1, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), - BUF(buf2, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), - BUF(buf3, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, pad_size)), + BUF(buf1, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), + BUF(buf2, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), + BUF(buf3, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), #undef BUF statsBits{u32(args.value("STATS", 0))}, @@ -941,20 +944,30 @@ vector Gpu::readData() { return readAndCompress(bufData); } // out := inA * inB; inB is preserved void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { + if (!in_place) { + fftP(tmp2, ioA); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(tmp1, tmp2, cache_group); + tailMul(tmp2, inB, tmp1, cache_group); + fftMidOut(tmp1, tmp2, cache_group); + fftW(tmp2, tmp1, cache_group); + } + } + else { fftP(tmp1, ioA); - for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - fftMidIn(tmp2, tmp1, cache_group); - tailMul(tmp1, inB, tmp2, cache_group); - fftMidOut(tmp2, tmp1, cache_group); - fftW(tmp1, tmp2, cache_group); + fftMidIn(tmp1, tmp1, cache_group); + tailMul(tmp1, inB, tmp1, cache_group); + fftMidOut(tmp1, tmp1, cache_group); + fftW(tmp2, tmp1, cache_group); } + } - // Register the current ROE pos as multiplication (vs. a squaring) - if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } + // Register the current ROE pos as multiplication (vs. a squaring) + if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } - if (mul3) { carryM(ioA, tmp1); } else { carryA(ioA, tmp1); } - carryB(ioA); + if (mul3) { carryM(ioA, tmp2); } else { carryA(ioA, tmp2); } + carryB(ioA); } void Gpu::mul(Buffer& io, Buffer& buf1) { @@ -964,13 +977,18 @@ void Gpu::mul(Buffer& io, Buffer& buf1) { // out := inA * inB; void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { - modMul(ioA, LEAD_NONE, inB, mul3); + modMul(ioA, inB, LEAD_NONE, mul3); }; // out := inA * inB; inB will end up in buf1 in the LEAD_MIDDLE state -void Gpu::modMul(Buffer& ioA, enum LEAD_TYPE leadInB, Buffer& inB, bool mul3) { - if (leadInB == LEAD_NONE) fftP(buf2, inB); - if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf2); +void Gpu::modMul(Buffer& ioA, Buffer& inB, enum LEAD_TYPE leadInB, bool mul3) { + if (!in_place) { + if (leadInB == LEAD_NONE) fftP(buf2, inB); + if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf2); + } else { + if (leadInB == LEAD_NONE) fftP(buf1, inB); + if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf1); + } mul(ioA, buf1, buf2, buf3, mul3); }; @@ -1123,32 +1141,51 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu if (exp == 0) { bufInOut.set(1); } else if (exp > 1) { - fftP(buf3, bufInOut); - fftMidIn(buf2, buf3); + if (!in_place) { + fftP(buf3, bufInOut); + fftMidIn(buf2, buf3); + } else { + fftP(buf2, bufInOut); + fftMidIn(buf2, buf2); + } fftHin(buf1, buf2); // save "base" to buf1 + bool midInAlreadyDone = 1; int p = 63; while (!testBit(exp, p)) { --p; } for (--p; ; --p) { for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - fftMidIn(buf2, buf3, cache_group); - tailSquare(buf3, buf2, cache_group); - fftMidOut(buf2, buf3, cache_group); + if (!in_place) { + if (!midInAlreadyDone) fftMidIn(buf2, buf3, cache_group); + tailSquare(buf3, buf2, cache_group); + fftMidOut(buf2, buf3, cache_group); + } else { + if (!midInAlreadyDone) fftMidIn(buf2, buf2, cache_group); + tailSquare(buf2, buf2, cache_group); + fftMidOut(buf2, buf2, cache_group); + } } + midInAlreadyDone = 0; if (testBit(exp, p)) { - doCarry(buf3, buf2); + doCarry(buf3, buf2, bufInOut); for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - fftMidIn(buf2, buf3, cache_group); - tailMulLow(buf3, buf2, buf1, cache_group); - fftMidOut(buf2, buf3, cache_group); + if (!in_place) { + fftMidIn(buf2, buf3, cache_group); + tailMulLow(buf3, buf2, buf1, cache_group); + fftMidOut(buf2, buf3, cache_group); + } else { + fftMidIn(buf2, buf2, cache_group); + tailMulLow(buf2, buf2, buf1, cache_group); + fftMidOut(buf2, buf2, cache_group); + } } } if (!p) { break; } - doCarry(buf3, buf2); + doCarry(buf3, buf2, bufInOut); } fftW(buf3, buf2); @@ -1158,30 +1195,63 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Bu } // does either carryFused() or the expanded version depending on useLongCarry -void Gpu::doCarry(Buffer& out, Buffer& in) { - if (useLongCarry) { - fftW(out, in); - carryA(in, out); - carryB(in); - fftP(out, in); +void Gpu::doCarry(Buffer& out, Buffer& in, Buffer& tmp) { + if (!in_place) { + if (useLongCarry) { + fftW(out, in); + carryA(tmp, out); + carryB(tmp); + fftP(out, tmp); + } else { + carryFused(out, in); + } } else { - carryFused(out, in); + if (useLongCarry) { + fftW(out, in); + carryA(tmp, out); + carryB(tmp); + fftP(in, tmp); + } else { + carryFused(in, in); + } } } +// Use buf1 and buf2 to do a single squaring. void Gpu::square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut, bool doMul3, bool doLL) { // leadOut = LEAD_MIDDLE is not supported (slower than LEAD_WIDTH) assert(leadOut != LEAD_MIDDLE); // LL does not do Mul3 assert(!(doMul3 && doLL)); - if (leadIn == LEAD_NONE) fftP(buf2, in); + // Not in place FFTs use buf1 and buf2 in a "ping pong" fashion. + // If leadIn is LEAD_NONE, in contains the input data, squaring starts at fftP + // If leadIn is LEAD_WIDTH, buf2 contains the input data, squaring starts at fftMidIn + // If leadIn is LEAD_MIDDLE, buf1 contains the input data, squaring starts at tailSquare + // If leadOut is LEAD_WIDTH, then will buf2 contain the output of carryFused -- to be used as input to the next squaring. + if (!in_place) { + if (leadIn == LEAD_NONE) fftP(buf2, in); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); + tailSquare(buf2, buf1, cache_group); + fftMidOut(buf1, buf2, cache_group); + if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + } + } - for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { - if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); - tailSquare(buf2, buf1, cache_group); - fftMidOut(buf1, buf2, cache_group); - if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + // In place FFTs use buf1. + // If leadIn is LEAD_NONE, in contains the input data, squaring starts at fftP + // If leadIn is LEAD_WIDTH, buf1 contains the input data, squaring starts at fftMidIn + // If leadIn is LEAD_MIDDLE, buf1 contains the input data, squaring starts at tailSquare + // If leadOut is LEAD_WIDTH, then buf1 will contain the output of carryFused -- to be used as input to the next squaring. + else { + if (leadIn == LEAD_NONE) fftP(buf1, in); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf1, cache_group); + tailSquare(buf1, buf1, cache_group); + fftMidOut(buf1, buf1, cache_group); + if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + } } // If leadOut is not allowed then we cannot use the faster carryFused kernel @@ -1201,9 +1271,9 @@ void Gpu::square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enu assert(!useLongCarry); assert(!doMul3); if (doLL) { - carryFusedLL(buf2, buf1); + carryFusedLL(in_place ? buf1 : buf2, buf1); } else { - carryFused(buf2, buf1); + carryFused(in_place ? buf1 : buf2, buf1); } } } @@ -1511,7 +1581,7 @@ tuple Gpu::measureCarry() { } enum LEAD_TYPE leadIn = LEAD_NONE; - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; @@ -1541,7 +1611,7 @@ tuple Gpu::measureCarry() { if (k >= iters) { break; } - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1585,7 +1655,7 @@ tuple Gpu::measureROE(bool quick) { } enum LEAD_TYPE leadIn = LEAD_NONE; - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; @@ -1615,7 +1685,7 @@ tuple Gpu::measureROE(bool quick) { if (k >= iters) { break; } - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1654,7 +1724,7 @@ double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest assert(dataResidue() == state.res64); enum LEAD_TYPE leadIn = LEAD_NONE; - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; @@ -1684,7 +1754,7 @@ double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest if (k >= iters) { break; } - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1771,7 +1841,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { if (skipNextCheckUpdate) { skipNextCheckUpdate = false; } else if (k % blockSize == 0) { - modMul(bufCheck, leadIn, bufData); + modMul(bufCheck, bufData, leadIn); leadIn = LEAD_MIDDLE; } diff --git a/src/Gpu.h b/src/Gpu.h index 86d5b200..fc5166f3 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -171,6 +171,7 @@ class Gpu { // Copy of some -use options needed for Kernel, Trig, and Weights initialization bool tail_single_wide; // TailSquare processes one line at a time bool tail_single_kernel; // TailSquare does not use a separate kernel for line zero + u32 in_place; // Should GPU perform transform in-place. 1 = nVidia friendly memory layout, 2 = AMD friendly. u32 pad_size; // Pad size in bytes as specified on the command line or config.txt. Maximum value is 512. // Twiddles: trigonometry constant buffers, used in FFTs. @@ -262,13 +263,13 @@ class Gpu { void writeState(u32 k, const vector& check, u32 blockSize); // does either carrryFused() or the expanded version depending on useLongCarry - void doCarry(Buffer& out, Buffer& in); + void doCarry(Buffer& out, Buffer& in, Buffer& tmp); void mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3 = false); void mul(Buffer& io, Buffer& inB); void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); - void modMul(Buffer& ioA, enum LEAD_TYPE leadInB, Buffer& inB, bool mul3 = false); + void modMul(Buffer& ioA, Buffer& inB, enum LEAD_TYPE leadInB, bool mul3 = false); fs::path saveProof(const Args& args, const ProofSet& proofSet); std::pair readROE(); @@ -342,13 +343,13 @@ class Gpu { // Compute the size of an FFT/NTT data buffer depending on the FFT/NTT float/prime. Size is returned in units of sizeof(double). // Data buffers require extra space for padding. We can probably tighten up the amount of extra memory allocated. -// The worst case seems to be MIDDLE=4, PAD_SIZE=512. - -#define MID_ADJUST(size,M,pad) ((pad == 0 || M != 4) ? (size) : (size) * 5/4) -#define PAD_ADJUST(N,M,pad) MID_ADJUST(pad == 0 ? N : pad <= 128 ? 9*N/8 : pad <= 256 ? 5*N/4 : 3*N/2, M, pad) -#define FP64_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) -#define FP32_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(float) / sizeof(double) -#define GF31_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(uint) / sizeof(double) -#define GF61_DATA_SIZE(W,M,H,pad) PAD_ADJUST(W*M*H*2, M, pad) * sizeof(ulong) / sizeof(double) -#define TOTAL_DATA_SIZE(fft,W,M,H,pad) fft.FFT_FP64 * FP64_DATA_SIZE(W,M,H,pad) + fft.FFT_FP32 * FP32_DATA_SIZE(W,M,H,pad) + \ - fft.NTT_GF31 * GF31_DATA_SIZE(W,M,H,pad) + fft.NTT_GF61 * GF61_DATA_SIZE(W,M,H,pad) +// The worst case seems to be !INPLACE, MIDDLE=4, PAD_SIZE=512. + +#define MID_ADJUST(size,M,pad) ((pad == 0 || M != 4) ? (size) : (size) * 5/4) +#define PAD_ADJUST(N,M,inplace,pad) (inplace ? 3*N/2 : MID_ADJUST(pad == 0 ? N : pad <= 128 ? 9*N/8 : pad <= 256 ? 5*N/4 : 3*N/2, M, pad)) +#define FP64_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) +#define FP32_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(float) / sizeof(double) +#define GF31_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(uint) / sizeof(double) +#define GF61_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(ulong) / sizeof(double) +#define TOTAL_DATA_SIZE(fft,W,M,H,inplace,pad) (int)fft.FFT_FP64 * FP64_DATA_SIZE(W,M,H,inplace,pad) + (int)fft.FFT_FP32 * FP32_DATA_SIZE(W,M,H,inplace,pad) + \ + (int)fft.NTT_GF31 * GF31_DATA_SIZE(W,M,H,inplace,pad) + (int)fft.NTT_GF61 * GF61_DATA_SIZE(W,M,H,inplace,pad) diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index 26638cd4..31db8bfc 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -345,6 +345,22 @@ void OVERLOAD middleShuffleWrite(global T2 *out, T2 *u, u32 workgroupSize, u32 b for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local T2 *lds, T2 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + #endif @@ -604,6 +620,20 @@ void OVERLOAD middleShuffleWrite(global F2 *out, F2 *u, u32 workgroupSize, u32 b for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} #endif @@ -709,6 +739,21 @@ void OVERLOAD middleShuffleWrite(global GF31 *out, GF31 *u, u32 workgroupSize, u for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + #endif @@ -835,6 +880,21 @@ void OVERLOAD middleShuffleWrite(global GF61 *out, GF61 *u, u32 workgroupSize, u for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + #endif diff --git a/src/cl/fftheight.cl b/src/cl/fftheight.cl index a2f67cb2..69fa75b8 100644 --- a/src/cl/fftheight.cl +++ b/src/cl/fftheight.cl @@ -8,7 +8,11 @@ #error SMALL_HEIGHT must be one of: 256, 512, 1024, 4096 #endif +#if !INPLACE u32 transPos(u32 k, u32 middle, u32 width) { return k / width + k % width * middle; } +#else +u32 transPos(u32 k, u32 middle, u32 width) { return k; } +#endif #if FFT_FP64 diff --git a/src/cl/fftmiddlein.cl b/src/cl/fftmiddlein.cl index 18b0835b..d1788799 100644 --- a/src/cl/fftmiddlein.cl +++ b/src/cl/fftmiddlein.cl @@ -5,6 +5,8 @@ #include "fft-middle.cl" #include "middle.cl" +#if !INPLACE // Original implementation (not in place) + #if FFT_FP64 KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { @@ -215,3 +217,199 @@ KERNEL(IN_WG) fftMiddleInGF61(P(T2) out, CP(T2) in, Trig trig) { } #endif + + + + + + +#else // in place transpose + +#if FFT_FP64 + +KERNEL(256) fftMiddleIn(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + T2 u[MIDDLE]; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = g / 131072; // A super tiny benefit (much smaller than margin of error) on TitanV +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in, y, x); + + middleMul2(u, x, y, 1, trig); + + fft_MIDDLE(u); + + middleMul(u, y, trig); + + // Transpose the x and y values + local T2 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(256) fftMiddleIn(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + F2 u[MIDDLE]; + + P(F2) inF2 = (P(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, FP32 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, inF2, y, x); + + middleMul2(u, x, y, 1, trigF2); + + fft_MIDDLE(u); + + middleMul(u, y, trigF2); + + // Transpose the x and y values + local F2 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(inF2 + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(256) fftMiddleInGF31(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF31 u[MIDDLE]; + + P(GF31) in31 = (P(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, GF31 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in31, y, x); + + middleMul2(u, x, y, trig31); + + fft_MIDDLE(u); + + middleMul(u, y, trig31); + + // Transpose the x and y values + local GF31 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in31 + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(256) fftMiddleInGF61(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF61 u[MIDDLE]; + + P(GF61) in61 = (P(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, GF61 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in61, y, x); + + middleMul2(u, x, y, trig61); + + fft_MIDDLE(u); + + middleMul(u, y, trig61); + + // Transpose the x and y values + local GF61 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in61 + zerohack, u, y, x); +} + +#endif + +#endif diff --git a/src/cl/fftmiddleout.cl b/src/cl/fftmiddleout.cl index 43e3e48f..a7263dcd 100644 --- a/src/cl/fftmiddleout.cl +++ b/src/cl/fftmiddleout.cl @@ -5,9 +5,11 @@ #include "fft-middle.cl" #include "middle.cl" +#if !INPLACE // Original implementation (not in place) + #if FFT_FP64 -KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { +KERNEL(OUT_WG) fftMiddleOut(P(T2) out, CP(T2) in, Trig trig) { T2 u[MIDDLE]; u32 SIZEY = OUT_WG / OUT_SIZEX; @@ -68,7 +70,7 @@ KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { #if FFT_FP32 -KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { +KERNEL(OUT_WG) fftMiddleOut(P(T2) out, CP(T2) in, Trig trig) { F2 u[MIDDLE]; CP(F2) inF2 = (CP(F2)) in; @@ -133,7 +135,7 @@ KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { #if NTT_GF31 -KERNEL(OUT_WG) fftMiddleOutGF31(P(T2) out, P(T2) in, Trig trig) { +KERNEL(OUT_WG) fftMiddleOutGF31(P(T2) out, CP(T2) in, Trig trig) { GF31 u[MIDDLE]; CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); @@ -192,7 +194,7 @@ KERNEL(OUT_WG) fftMiddleOutGF31(P(T2) out, P(T2) in, Trig trig) { #if NTT_GF61 -KERNEL(OUT_WG) fftMiddleOutGF61(P(T2) out, P(T2) in, Trig trig) { +KERNEL(OUT_WG) fftMiddleOutGF61(P(T2) out, CP(T2) in, Trig trig) { GF61 u[MIDDLE]; CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); @@ -243,3 +245,200 @@ KERNEL(OUT_WG) fftMiddleOutGF61(P(T2) out, P(T2) in, Trig trig) { } #endif + + + +#else // in place transpose + +#if FFT_FP64 + +KERNEL(256) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + T2 u[MIDDLE]; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in, y, x); + + middleMul(u, x, trig); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. + double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trig); + + // Transpose the x and y values + local T2 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(256) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + F2 u[MIDDLE]; + + P(F2) inF2 = (P(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, inF2, y, x); + + middleMul(u, x, trigF2); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. + double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trigF2); + + // Transpose the x and y values + local F2 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(outF2, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(256) fftMiddleOutGF31(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF31 u[MIDDLE]; + + P(GF31) in31 = (P(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in31, y, x); + + middleMul(u, x, trig31); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig31); + + // Transpose the x and y values + local GF31 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out31, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(256) fftMiddleOutGF61(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF61 u[MIDDLE]; + + P(GF61) in61 = (P(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in61, y, x); + + middleMul(u, x, trig61); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig61); + + // Transpose the x and y values + local GF61 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out61, u, y, x); +} + +#endif + +#endif diff --git a/src/cl/middle.cl b/src/cl/middle.cl index ac649881..263937f3 100644 --- a/src/cl/middle.cl +++ b/src/cl/middle.cl @@ -34,6 +34,8 @@ #define MIDDLE_OUT_LDS_TRANSPOSE 1 #endif +#if !INPLACE // Original implementation (not in place) + #if FFT_FP64 || NTT_GF61 //**************************************************************************************** @@ -669,3 +671,400 @@ void OVERLOAD readCarryFusedLine(CP(GF61) in, GF61 *u, u32 line) { } #endif + + + + + + +#else // New implementation (in-place) + +// Goals: +// 1) In-place transpose. Rather than "ping-pong"ing buffers, an in-place transpose uses half as much memory. This may allow +// the entire FFT/NTT data set to reside in the L2 cache on upper end consumer GPUs (circa 2025) which can have 64MB or larger L2 caches. +// 2) We want to have distribute the carryFused and/or tailSquare memory in the L2 cache with minimal cache line collisions. The hope is to (one day) do +// fftMiddleOut/carryFused/fftMiddleIn or fftMiddleIn/tailSquare/fftMiddleOut in L2 cache-sized chunks to minimize the slowest memory accesses. +// The cost of extra kernel launches may negate any L2 cache benefits. +// 3) We use swizzling and/or modest padding to reduce carryFused L2 cache line collisions. Several different memory layouts and padding were tried +// on nVidia Titan V and AMD Radeon VII to find the fastest in-place layout and padding scheme. Hopefully, these will schemes will work well +// on later generation GPUs with different L2 cache dimensions (size and "number-of-ways"). +// 4) Apparently cache line collisions in the L1 cache also adversely affect timings. The L1 cache may have a different cache line size and number-of-ways +// which makes padding tuning a bit difficult. This is especially true on AMD which has a very strange channel & banks partitioning of memory accesses. +// 5) Manufacturers are not very good about documenting L1 & L2 cache configurations. nVidia GPUs seem to have a L2 cache line size of 128 bytes. +// AMD documentation seems to indicate 64 or 128 byte cache lines. However, other documentation indicates 256 byte reads are to be preferred. +// Clinfo on an rx9070 says the L2 cache line size is 256 bytes. Thus, our goal is to target cache line size of 256 bytes. +// +// Here is the proposed memory layout for a 512 x 8 x 512 = 2M complex FFT using two doubles (16 bytes). PRPLL calls this a 4M FFT for the end user. +// The in-place transpose done by fftMiddleIn and fftMiddleOut works on a grid of 16 values from carryFused (multiples of 4K) and 16 values +// from tailSquare (multiples of 1). Sixteen 16-byte values = one 256 byte cache line or two 128 byte cache lines. +// +// tailSquare memory layout (also fftMiddleIn output layout and fftMiddleOut input layout). +// A HEIGHT=512 tail line is 512*16 bytes = 8KB. There are 2M/512=4K lines. Lines k and N-k are processed together. +// A 2MB L2 cache can fit 256 tail lines. So the first group of tail lines output by fftMiddleIn can be (mostly) paired with the last group of +// tail lines output by fftMiddleIn for tailSquare to process. 128 tail lines = 64K FFT values = 1MB. +// Here is the memory layout of FFT data. +// 0..511 The first tailSquare line (tail line 0). SMALL_HEIGHT values * 16 bytes (8KB). +// 4K The tailSquare line starting with FFT data element 4K (tail line 1). +// .. +// 60K These 16 lines will form 16x16 "transpose blocks". +// 64K Follow the same pattern starting at FFT data element 64K (tail line 16). Note: We have tried placing the "middle" lines next. +// ... +// 2M-4K The last "width" line (tail line MIDDLE*(WIDTH-1)). +// 512.. The tailSquare line starting with the first "middle" value (tail line WIDTH, i.e. 512), followed again by WIDTH lines 4K apart. +// ... +// 3.5K The tailSquare line containing the last middle value that fftMiddleOut will need (tail line (MIDDLE-1)*WIDTH) +// ... +// +// fftMiddleOut uses 16 sets of 16 threads to transpose 16x16 blocks of 16-byte values. Those 256 threads also process the MIDDLE blocks. For example, +// the 256 threads read 0,1..15, 4K+0,4K+1..4K+15, ..., 60K+0..60K+15 to form one 16x16 block. For MIDDLE processing, those 256 threads also read +// the seven 16x16 blocks beginning at 512, 1024, ... 3584. +// +// After transposing, the memory layout of fftMiddleOut output (also carryFused input & output and fftMiddleIn layout) is: +// 0,4K..60K 16..16+60K, 32..32+60K, ..., 496..496+60K (SMALL_HEIGHT/16=32) groups of 16 values * 16 bytes (8KB) +// ... +// 15.. +// 64K After the above 16 lines (128KB), follow the same pattern starting at FFT data element 64K +// ... +// 2.5M-64K +// 512.. After WIDTH lines above, follow the same pattern starting at the first "middle" value +// ... +// 3.5K The last "middle" value +// ... +// +// carryFused reads FFT data values that are 4K apart. The first 16 are in a single 256 byte cache line. The next 16 (starting at 64K) occurs 128KB later. +// This large power of two stride could cause cache collisions depending on the cache layout. If so, padding or swizzling MAY be very useful. +// carryFused does not reuse any read in data, so in theory its no big deal if a cache line is evicted before writing the result back out. +// I'm not sure if cache hardware cares about having multiple reads to the same cache line in-flight. NOTE: One minor advantage to eliminating cache +// conflicts is that carryFused likely processes lines in order or close to it. If we next process MiddleIn/tailSquare/MiddleOut in a 2MB cache as +// described above, it will appreciate the last 128 lines being in the L2 cache from carryfused. +// +// Back to carryfused. What might be an optimal padding scheme? Say that 64 carryfuseds are active simultaneously (probably more). +// The +1s are +8KB +// The +16s are +256B +// The +64Ks are at +128KB +// This leaves the "columns" starting at +1KB unused - suggesting a pad of +1KB before the 64Ks would yield a better distribution in the L2 cache. +// However, if we have say an 8-way 16MB L2 cache then each way contains 2MB. If so, we'd want to pad 1KB before the 16th 64K FFT data value. + +#if FFT_FP64 || NTT_GF61 + +//**************************************************************************************** +// Pair of routines to read/write data to/from carryFused +//**************************************************************************************** + +#if INPLACE == 1 // nVidia friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK (SMALL_HEIGHT + 16) // Pad 256 bytes +//#define SIZEW (16 * SIZEA + 16) // Pad 256 bytes +//#define SIZEM (MIDDLE * SIZEB + (1 - (MIDDLE & 1)) * 16) // Pad 256 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW (16 * SIZEBLK + 16) // Pad 256 bytes +#define SIZEM (WIDTH / 16 * SIZEW + 16) // Pad 256 bytes +#define SWIZ(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#else // AMD friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK (SMALL_HEIGHT + 16) // Pad 256 bytes +//#define SIZEM (16 * SIZEBLK + 16) // Pad 256 bytes +//#define SIZEW (MIDDLE * SIZEM + (1 - (MIDDLE & 1)) * 16) // Pad 256 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW (16 * SIZEBLK + 16) // Pad 256 bytes +#define SIZEM (WIDTH / 16 * SIZEW + 0) // Pad 0 bytes +#define SWIZ(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#endif + +// me ranges 0...WIDTH/NW-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0..NW-1 (big multiples of BIG_HEIGHT) +// line ranges 0...BIG_HEIGHT-1 (multiples of one) + +// Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. +void OVERLOAD readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + in += (me / 16 * SIZEW) + (middle * SIZEM) + (line % 16 * SIZEBLK) + SWIZ(line % 16, line / 16) * 16 + (me % 16); + for (u32 i = 0; i < NW; ++i) { u[i] = NTLOAD(in[i * G_W / 16 * SIZEW]); } +} + +// Write a line from carryFused. This data will be read by fftMiddleIn. +void OVERLOAD writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + out += (me / 16 * SIZEW) + (middle * SIZEM) + (line % 16 * SIZEBLK) + SWIZ(line % 16, line / 16) * 16 + (me % 16); + for (i32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W / 16 * SIZEW], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleIn +//**************************************************************************************** + +// x ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...SMALL_HEIGHT-1 (multiples of one) + +void OVERLOAD readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { + in += (x / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM]); } +} + +// NOTE: writeMiddleInLine uses the same definition of x,y as readMiddleInLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleInLine. +void OVERLOAD writeMiddleInLine (P(T2) out, T2 *u, u32 y, u32 x) +{ + out += (x / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from tailSquare/Mul +//**************************************************************************************** + +// me ranges 0...SMALL_HEIGHT/NH-1 (multiples of one) +// u[i] ranges 0...NH-1 (big multiples of one) +// line ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) + +// Read a line for tailSquare/Mul or fftHin +void OVERLOAD readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + in += (width / 16 * SIZEW) + (middle * SIZEM) + (width % 16 * SIZEBLK) + (me % 16); + for (i32 i = 0; i < NH; ++i) { u[i] = NTLOAD(in[SWIZ(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16]); } +} + +void OVERLOAD writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + out += (width / 16 * SIZEW) + (middle * SIZEM) + (width % 16 * SIZEBLK) + (me % 16); + for (i32 i = 0; i < NH; ++i) { NTSTORE(out[SWIZ(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleOut +//**************************************************************************************** + +// x ranges 0...SMALL_HEIGHT-1 (multiples of one) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) + +void OVERLOAD readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { + in += (y / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM]); } +} + +// NOTE: writeMiddleOutLine uses the same definition of x,y as readMiddleOutLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleOutLine. +void OVERLOAD writeMiddleOutLine (P(T2) out, T2 *u, u32 y, u32 x) +{ + out += (y / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM], u[i]); } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 || NTT_GF31 + +//**************************************************************************************** +// Pair of routines to read/write data to/from carryFused +//**************************************************************************************** + +// NOTE: I hven't studied best padding/swizzling/memory layout for 32-bit values. I'm assuming the 64-bit values scheme will be pretty good. +#if INPLACE == 1 // nVidia friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK32 (SMALL_HEIGHT + 16) // Pad 128 bytes +//#define SIZEW32 (16 * SIZEA + 16) // Pad 128 bytes +//#define SIZEM32 (MIDDLE * SIZEB + (1 - (MIDDLE & 1)) * 16) // Pad 128 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK32 (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW32 (16 * SIZEBLK + 16) // Pad 128 bytes +#define SIZEM32 (WIDTH / 16 * SIZEW + 16) // Pad 128 bytes +#define SWIZ32(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#else // AMD friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK32 (SMALL_HEIGHT + 16) // Pad 128 bytes +//#define SIZEM32 (16 * SIZEBLK + 16) // Pad 128 bytes +//#define SIZEW32 (MIDDLE * SIZEM + (1 - (MIDDLE & 1)) * 16) // Pad 128 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK32 (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW32 (16 * SIZEBLK + 16) // Pad 128 bytes +#define SIZEM32 (WIDTH / 16 * SIZEW + 0) // Pad 0 bytes +#define SWIZ32(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#endif + +// me ranges 0...WIDTH/NW-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0..NW-1 (big multiples of BIG_HEIGHT) +// line ranges 0...BIG_HEIGHT-1 (multiples of one) + +// Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. +void OVERLOAD readCarryFusedLine(CP(F2) in, F2 *u, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + in += (me / 16 * SIZEW32) + (middle * SIZEM32) + (line % 16 * SIZEBLK32) + SWIZ32(line % 16, line / 16) * 16 + (me % 16); + for (u32 i = 0; i < NW; ++i) { u[i] = NTLOAD(in[i * G_W / 16 * SIZEW32]); } +} + +// Write a line from carryFused. This data will be read by fftMiddleIn. +void OVERLOAD writeCarryFusedLine(F2 *u, P(F2) out, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + out += (me / 16 * SIZEW32) + (middle * SIZEM32) + (line % 16 * SIZEBLK32) + SWIZ32(line % 16, line / 16) * 16 + (me % 16); + for (i32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W / 16 * SIZEW32], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleIn +//**************************************************************************************** + +// x ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...SMALL_HEIGHT-1 (multiples of one) + +void OVERLOAD readMiddleInLine(F2 *u, CP(F2) in, u32 y, u32 x) { + in += (x / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM32]); } +} + +// NOTE: writeMiddleInLine uses the same definition of x,y as readMiddleInLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleInLine. +void OVERLOAD writeMiddleInLine (P(F2) out, F2 *u, u32 y, u32 x) +{ + out += (x / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM32], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from tailSquare/Mul +//**************************************************************************************** + +// me ranges 0...SMALL_HEIGHT/NH-1 (multiples of one) +// u[i] ranges 0...NH-1 (big multiples of one) +// line ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) + +// Read a line for tailSquare/Mul or fftHin +void OVERLOAD readTailFusedLine(CP(F2) in, F2 *u, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + in += (width / 16 * SIZEW32) + (middle * SIZEM32) + (width % 16 * SIZEBLK32) + (me % 16); + for (i32 i = 0; i < NH; ++i) { u[i] = NTLOAD(in[SWIZ32(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16]); } +} + +void OVERLOAD writeTailFusedLine(F2 *u, P(F2) out, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + out += (width / 16 * SIZEW32) + (middle * SIZEM32) + (width % 16 * SIZEBLK32) + (me % 16); + for (i32 i = 0; i < NH; ++i) { NTSTORE(out[SWIZ32(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleOut +//**************************************************************************************** + +// x ranges 0...SMALL_HEIGHT-1 (multiples of one) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) + +void OVERLOAD readMiddleOutLine(F2 *u, CP(F2) in, u32 y, u32 x) { + in += (y / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM32]); } +} + +// NOTE: writeMiddleOutLine uses the same definition of x,y as readMiddleOutLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleOutLine. +void OVERLOAD writeMiddleOutLine (P(F2) out, F2 *u, u32 y, u32 x) +{ + out += (y / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM32], u[i]); } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Since F2 and GF31 are the same size we can simply call the floats based code + +void OVERLOAD readCarryFusedLine(CP(GF31) in, GF31 *u, u32 line) { + readCarryFusedLine((CP(F2)) in, (F2 *) u, line); +} + +void OVERLOAD writeCarryFusedLine(GF31 *u, P(GF31) out, u32 line) { + writeCarryFusedLine((F2 *) u, (P(F2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleInLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF31) out, GF31 *u, u32 y, u32 x) { + writeMiddleInLine ((P(F2)) out, (F2 *) u, y, x); +} + +void OVERLOAD readTailFusedLine(CP(GF31) in, GF31 *u, u32 line, u32 me) { + readTailFusedLine((CP(F2)) in, (F2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF31 *u, P(GF31) out, u32 line, u32 me) { + writeTailFusedLine((F2 *) u, (P(F2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleOutLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF31) out, GF31 *u, u32 y, u32 x) { + writeMiddleOutLine ((P(F2)) out, (F2 *) u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Since T2 and GF61 are the same size we can simply call the doubles based code + +void OVERLOAD readCarryFusedLine(CP(GF61) in, GF61 *u, u32 line) { + readCarryFusedLine((CP(T2)) in, (T2 *) u, line); +} + +void OVERLOAD writeCarryFusedLine(GF61 *u, P(GF61) out, u32 line) { + writeCarryFusedLine((T2 *) u, (P(T2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleInLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF61) out, GF61 *u, u32 y, u32 x) { + writeMiddleInLine ((P(T2)) out, (T2 *) u, y, x); +} + +void OVERLOAD readTailFusedLine(CP(GF61) in, GF61 *u, u32 line, u32 me) { + readTailFusedLine((CP(T2)) in, (T2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF61 *u, P(GF61) out, u32 line, u32 me) { + writeTailFusedLine((T2 *) u, (P(T2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleOutLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF61) out, GF61 *u, u32 y, u32 x) { + writeMiddleOutLine ((P(T2)) out, (T2 *) u, y, x); +} + +#endif + +#endif diff --git a/src/tune.cpp b/src/tune.cpp index 15eedef2..51f892f8 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -439,6 +439,10 @@ void Tune::tune() { u32 variant = (defaultShape == &defaultFFTShape) ? 101 : 202; //GW: if fft spec on the command line specifies a variant then we should use that variant (I get some interesting results with 000 vs 101 vs 201 vs 202 likely due to rocm optimizer) + // IN_WG/SIZEX, OUT_WG/SIZEX, PAD, MIDDLE_IN/OUT_LDS_TRANSPOSE apply only if INPLACE=0 + u32 current_inplace = args->value("INPLACE", 0); + args->flags["INPLACE"] = to_string(0); + // Find best IN_WG,IN_SIZEX,OUT_WG,OUT_SIZEX settings if (1/*option to time IN/OUT settings*/) { FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; @@ -508,6 +512,65 @@ void Tune::tune() { args->flags["PAD"] = to_string(best_pad); } + // Find best MIDDLE_IN_LDS_TRANSPOSE setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_in_lds_transpose = 0; + u32 current_middle_in_lds_transpose = args->value("MIDDLE_IN_LDS_TRANSPOSE", 1); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 middle_in_lds_transpose : {0, 1}) { + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); + if (middle_in_lds_transpose == current_middle_in_lds_transpose) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } + } + log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_IN_LDS_TRANSPOSE", best_middle_in_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); + } + + // Find best MIDDLE_OUT_LDS_TRANSPOSE setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_out_lds_transpose = 0; + u32 current_middle_out_lds_transpose = args->value("MIDDLE_OUT_LDS_TRANSPOSE", 1); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 middle_out_lds_transpose : {0, 1}) { + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); + if (middle_out_lds_transpose == current_middle_out_lds_transpose) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } + } + log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_OUT_LDS_TRANSPOSE", best_middle_out_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); + } + + // Find best INPLACE setting + if (time_FFTs) { + FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_inplace = 0; + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 inplace : {0, 1}) { + args->flags["INPLACE"] = to_string(inplace); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using INPLACE=%u is %6.1f\n", fft.spec().c_str(), inplace, cost); + if (inplace == current_inplace) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_inplace = inplace; } + } + log("Best INPLACE is %u. Default INPLACE is 0. Best INPLACE setting may change when using larger FFTs.\n", best_inplace); + configsUpdate(current_cost, best_cost, 0.002, "INPLACE", best_inplace, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["INPLACE"] = to_string(best_inplace); + } + // Find best NONTEMPORAL setting if (1) { FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; @@ -838,46 +901,6 @@ void Tune::tune() { args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); } - // Find best MIDDLE_IN_LDS_TRANSPOSE setting - if (1) { - FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; - u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_middle_in_lds_transpose = 0; - u32 current_middle_in_lds_transpose = args->value("MIDDLE_IN_LDS_TRANSPOSE", 1); - double best_cost = -1.0; - double current_cost = -1.0; - for (u32 middle_in_lds_transpose : {0, 1}) { - args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); - log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); - if (middle_in_lds_transpose == current_middle_in_lds_transpose) current_cost = cost; - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } - } - log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); - configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_IN_LDS_TRANSPOSE", best_middle_in_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); - args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); - } - - // Find best MIDDLE_OUT_LDS_TRANSPOSE setting - if (1) { - FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; - u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_middle_out_lds_transpose = 0; - u32 current_middle_out_lds_transpose = args->value("MIDDLE_OUT_LDS_TRANSPOSE", 1); - double best_cost = -1.0; - double current_cost = -1.0; - for (u32 middle_out_lds_transpose : {0, 1}) { - args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); - log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); - if (middle_out_lds_transpose == current_middle_out_lds_transpose) current_cost = cost; - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } - } - log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); - configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_OUT_LDS_TRANSPOSE", best_middle_out_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); - args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); - } - // Find best BIGLIT setting if (time_FFTs) { FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; @@ -923,7 +946,7 @@ void Tune::tune() { } if (args->workers < 2) { config.write("\n# Running two workers sometimes gives better throughput."); - config.write("\n# Changing TAIL_KERNELS to 1 or 3 with two workers may be better."); + config.write("\n# Changing TAIL_KERNELS to 3 with two workers may be better."); config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); } } From 46386d037d3eed0364c2ca5ef69718e9bab1435e Mon Sep 17 00:00:00 2001 From: george Date: Wed, 10 Dec 2025 03:18:14 +0000 Subject: [PATCH 113/115] Fixed tuning INPLACE bug --- src/tune.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tune.cpp b/src/tune.cpp index 51f892f8..cc4d70c1 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -553,7 +553,7 @@ void Tune::tune() { } // Find best INPLACE setting - if (time_FFTs) { + if (1) { FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_inplace = 0; From 4db49dc02db6de89c37c2a77b19556d223d5be76 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 12 Dec 2025 18:33:46 +0000 Subject: [PATCH 114/115] Don't allow testing non-prime exponents. They sometimes raise excessive roundoff errors. --- src/Task.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/Task.cpp b/src/Task.cpp index 676359ec..d72f4db2 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -8,6 +8,7 @@ #include "Worktodo.h" #include "Saver.h" #include "version.h" +#include "Primes.h" #include "Proof.h" #include "log.h" #include "timeutil.h" @@ -211,6 +212,20 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { assert(exponent); + // Testing exponent 140000001 using FFT 512:15:512 fails with severe round off errors. + // I'm guessing this is because bot the exponent and FFT size are divisible by 3. + // Here we make sure the exponent is prime. If not we do not raise an error because it + // is very common to use command line argument "-prp some-random-exponent" to get a quick + // timing. Instead, we output a warning and test a smaller prime exponent. + { + Primes primes; + if (!primes.isPrime(exponent)) { + u32 new_exponent = primes.prevPrime(exponent); + log("Warning: Exponent %u is not prime. Using exponent %u instead.\n", exponent, new_exponent); + exponent = new_exponent; + } + } + LogContext pushContext(std::to_string(exponent)); FFTConfig fft = FFTConfig::bestFit(*shared.args, exponent, shared.args->fftSpec); From 09ed5447e4be79f424c6aa356e475b6c446c3a69 Mon Sep 17 00:00:00 2001 From: george Date: Fri, 12 Dec 2025 18:46:22 +0000 Subject: [PATCH 115/115] Added no fp32 tune option (a Windows user found OpenCL compiler choking on use of fma function on floats). Some minor changes on wording of tune output. --- src/tune.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/tune.cpp b/src/tune.cpp index cc4d70c1..3b773437 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -347,6 +347,7 @@ void Tune::tune() { bool tune_config = 1; bool time_FFTs = 0; bool time_NTTs = 0; + bool time_FP32 = 1; int quick = 7; // Run config from slowest (quick=1) to fastest (quick=10) u64 min_exponent = 75000000; u64 max_exponent = 350000000; @@ -358,6 +359,7 @@ void Tune::tune() { if (s == "noconfig") tune_config = 0; if (s == "fp64") time_FFTs = 1; if (s == "ntt") time_NTTs = 1; + if (s == "nofp32") time_FP32 = 0; auto keyVal = split(s, '='); if (keyVal.size() == 2) { if (keyVal.front() == "quick") quick = stod(keyVal.back()); @@ -554,7 +556,7 @@ void Tune::tune() { // Find best INPLACE setting if (1) { - FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_inplace = 0; double best_cost = -1.0; @@ -566,7 +568,7 @@ void Tune::tune() { if (inplace == current_inplace) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_inplace = inplace; } } - log("Best INPLACE is %u. Default INPLACE is 0. Best INPLACE setting may change when using larger FFTs.\n", best_inplace); + log("Best INPLACE is %u. Default INPLACE is 0. Best INPLACE setting may be different for other FFT lengths.\n", best_inplace); configsUpdate(current_cost, best_cost, 0.002, "INPLACE", best_inplace, newConfigKeyVals, suggestedConfigKeyVals); args->flags["INPLACE"] = to_string(best_inplace); } @@ -676,7 +678,7 @@ void Tune::tune() { } // Find best TAIL_TRIGS32 setting - if (time_NTTs) { + if (time_NTTs && time_FP32) { FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw @@ -759,7 +761,7 @@ void Tune::tune() { } // Find best TABMUL_CHAIN32 setting - if (time_NTTs) { + if (time_NTTs && time_FP32) { FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw @@ -945,9 +947,10 @@ void Tune::tune() { config.write("\n -log 1000000\n"); } if (args->workers < 2) { - config.write("\n# Running two workers sometimes gives better throughput."); - config.write("\n# Changing TAIL_KERNELS to 3 with two workers may be better."); - config.write("\n# -workers 2 -use TAIL_KERNELS=3\n"); + config.write("\n# Running two workers sometimes gives better throughput. Autoprimenet will need to create up a second worktodo file."); + config.write("\n# -workers 2\n"); + config.write("\n# Changing TAIL_KERNELS to 3 when running two workers may be better."); + config.write("\n# -use TAIL_KERNELS=3\n"); } } @@ -981,6 +984,7 @@ skip_1K_256 = 0; // Skip some FFTs and NTTs if (shape.fft_type == FFT64 && !time_FFTs) continue; if (shape.fft_type != FFT64 && !time_NTTs) continue; + if ((shape.fft_type == FFT3261 || shape.fft_type == FFT323161 || shape.fft_type == FFT3231 || shape.fft_type == FFT32) && !time_FP32) continue; // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp());